DELE ST1504 CA2 Part A: Generative Adversarial Network
Objective:
Develop a Generative Adversarial Network (GAN) model for image generation, utilizing the CIFAR10 dataset. The model aims to generate 1000 high-quality, small color images in 10 distinct classes, showcasing its ability to learn and replicate complex visual patterns.
Background:
GANs are a revolutionary class of artificial neural networks used in unsupervised machine learning tasks. They consist of two parts: a Generator, which creates images, and a Discriminator, which evaluates them. The objective is to train a GAN that excels in producing diverse, realistic images that closely mimic the characteristics of the CIFAR10 dataset.
Key Features:
Implement and evaluate different GAN architectures to determine the most effective model for the CIFAR10 image generation task. The chosen model should generate images that are not only realistic and visually appealing but also varied across the 10 classes in the dataset.
Output Specification:
The model will produce images that are evaluated based on their similarity to the real images in the CIFAR10 dataset and their diversity across the dataset's classes. The performance of the GAN will be a crucial indicator of its effectiveness in learning and replicating complex patterns from a given dataset.
import gc
import os
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from skimage.transform import resize
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.losses import BinaryCrossentropy, Hinge
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.regularizers import l1, l2, l1_l2
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.layers import Dense, Reshape, UpSampling2D, Conv2D, BatchNormalization, LeakyReLU, ZeroPadding2D, Dropout, Flatten, Input, Activation, GlobalMaxPooling2D, Conv2DTranspose, PReLU, Embedding, Concatenate
from tensorflow.keras.metrics import Mean
#from tensorflow_addons.layers import SpectralNormalization
import GAN_function as gnnf
from warnings import simplefilter
simplefilter(action='ignore', category=UserWarning)
simplefilter(action='ignore', category=FutureWarning)
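# Fix random seed for reproducibility
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)
# Note: set_random_seed seeds Python, NumPy and TensorFlow together, so this
# call overrides the two seeds set above with 0.
tf.keras.utils.set_random_seed(0)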
# Check whether a GPU is available
gpus = tf.config.experimental.list_physical_devices('GPU')
# Memory control: prevent TensorFlow from allocating all GPU memory up front
for gpu in gpus:
    try:
        print(tf.config.experimental.get_device_details(gpu))
    except Exception:
        # Device details are informational only; skip them if unavailable
        pass
    tf.config.experimental.set_memory_growth(gpu, True)
print(f"There are {len(gpus)} GPU(s) present.")
{'device_name': 'NVIDIA GeForce RTX 3060', 'compute_capability': (8, 6)}
There are 1 GPU(s) present.
CIFAR10 Dataset:
Images:
Classes:
Batches:
x_train: uint8 NumPy array of RGB image data with shape (50000, 32, 32, 3), containing the training data. Pixel values range from 0 to 255.
y_train: uint8 NumPy array of labels (integers in range 0-9) with shape (50000, 1) for the training data.
x_test: uint8 NumPy array of RGB image data with shape (10000, 32, 32, 3), containing the test data. Pixel values range from 0 to 255.
y_test: uint8 NumPy array of labels (integers in range 0-9) with shape (10000, 1) for the test data.
# Load CIFAR10 dataset
cifar10 = tf.keras.datasets.cifar10.load_data()
(X_train, y_train), (X_test, y_test) = cifar10
# Optionally combine the train and test sets (disabled, so only the training set is used)
# X_train = np.concatenate((X_train, X_test), axis=0)
# y_train = np.concatenate((y_train, y_test), axis=0)
# For EDA purposes
eda_data = X_train
# Print the shapes of the training set
print(f"Shape of X_train (features): {X_train.shape}")
print(f"Shape of y_train (labels): {y_train.shape}")
Shape of X_train (features): (50000, 32, 32, 3)
Shape of y_train (labels): (50000, 1)
Define Class Labels
# Map integer class labels to their corresponding class names
class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
Image Pixel Distribution
# Obtain per-channel pixel statistics
# (named pixel_min / pixel_max to avoid shadowing the built-in min() and max())
pixel_min = np.min(eda_data, axis=(0, 1, 2))
pixel_max = np.max(eda_data, axis=(0, 1, 2))
mean = np.mean(eda_data, axis=(0, 1, 2))
std = np.std(eda_data, axis=(0, 1, 2))
# Print Statistics
print("\nPixel Statistics for the original train dataset:")
print(f"Minimum pixel value: {pixel_min}")
print(f"Maximum pixel value: {pixel_max}")
print(f"Mean pixel value: {mean}")
print(f"Standard deviation of pixel values: {std}")
Pixel Statistics for the original train dataset:
Minimum pixel value: [0 0 0]
Maximum pixel value: [255 255 255]
Mean pixel value: [125.30691805 122.95039414 113.86538318]
Standard deviation of pixel values: [62.99321928 62.08870764 66.70489964]
Class Distribution
classes, counts = np.unique(y_train, return_counts=True)
class_count_dict = dict(zip(class_names, counts))
df = pd.DataFrame({'Count': class_count_dict})
df
| | Count |
|---|---|
| Airplane | 5000 |
| Automobile | 5000 |
| Bird | 5000 |
| Cat | 5000 |
| Deer | 5000 |
| Dog | 5000 |
| Frog | 5000 |
| Horse | 5000 |
| Ship | 5000 |
| Truck | 5000 |
# Visualise Distribution of Image Classes
gnnf.plot_counts(class_count_dict)
gnnf.plot_pie_chart(class_count_dict)
Insights:
Dataset Visualization
Display sample of images from each of the 10 classes.
# Create figure & set size
fig, axes = plt.subplots(10, 10, figsize=(30, 30))
for i in range(len(class_names)):
    # Indices of all training images that belong to class i
    class_indices = np.where(y_train == i)[0]
    # Randomly select ten images
    random_indices = np.random.choice(class_indices, 10, replace=False)
    for j, image_index in enumerate(random_indices):
        # Show the image without axes, titled with its class name
        axes[i, j].imshow(eda_data[image_index])
        axes[i, j].axis('off')
        axes[i, j].set_title(class_names[i])
plt.suptitle('10 Random Images Per Class', fontsize=25)
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()
Insights:
Image Averaging for Pixel Distribution
gnnf.average_image(eda_data)
Insights:
gnnf.average_images_per_class(eda_data, y_train, class_names)
Insights:
Currently, the input pixel values are in the range [0, 255].
We therefore rescale them to the range [-1, 1]: large integer inputs can slow down training, and [-1, 1] matches the output range of the tanh activation used in the generator's final layer.
# Scale from [0,255] to [-1,1]
X_train_rescaled = (X_train / 127.5 - 1.).astype('float32')
# Obtain per-channel pixel statistics after rescaling
# (named pixel_min / pixel_max to avoid shadowing the built-in min() and max())
pixel_min = np.min(X_train_rescaled, axis=(0, 1, 2))
pixel_max = np.max(X_train_rescaled, axis=(0, 1, 2))
mean = np.mean(X_train_rescaled, axis=(0, 1, 2))
std = np.std(X_train_rescaled, axis=(0, 1, 2))
# Print Statistics
print("Pixel Statistics for the rescaled train dataset:")
print(f"Minimum pixel value: {pixel_min}")
print(f"Maximum pixel value: {pixel_max}")
print(f"Mean pixel value: {mean}")
print(f"Standard deviation of pixel values: {std}\n")
Pixel Statistics for the rescaled train dataset:
Minimum pixel value: [-1. -1. -1.]
Maximum pixel value: [1. 1. 1.]
Mean pixel value: [-0.01720063 -0.03568322 -0.10693764]
Standard deviation of pixel values: [0.49406543 0.48696858 0.52317506]
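The rescaling above has a simple inverse, which is useful later when generated images need to be displayed or saved. The following is a minimal sketch of the round trip, assuming NumPy is available; the helper names `to_model_range` and `to_pixel_range` are illustrative and not part of the notebook.

```python
import numpy as np

def to_model_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (images_uint8 / 127.5 - 1.0).astype("float32")

def to_pixel_range(images_float):
    """Map values in [-1, 1] back to uint8 pixels in [0, 255]."""
    pixels = (images_float + 1.0) * 127.5
    # Round before casting so that small float32 errors do not shift a pixel value
    return np.clip(np.rint(pixels), 0, 255).astype("uint8")

# Every possible pixel value, shaped like a tiny single-channel image
all_pixels = np.arange(256, dtype=np.uint8).reshape(16, 16, 1)
scaled = to_model_range(all_pixels)
restored = to_pixel_range(scaled)

print(scaled.min(), scaled.max())            # extremes of the rescaled range
print(np.array_equal(restored, all_pixels))  # whether the round trip is lossless
```

Rounding before the cast matters: a plain `astype("uint8")` truncates, so a value such as 126.99999 would become 126 instead of 127.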
# Plot first image from eda_data
plt.subplot(1, 2, 1)
plt.imshow(eda_data[0])
plt.title('Original Image')
# Plot first image from X_train_rescaled, mapped from [-1, 1] back to [0, 1]
# because imshow expects float RGB data in [0, 1]
plt.subplot(1, 2, 2)
plt.imshow((X_train_rescaled[0] + 1) / 2)
plt.title('Rescaled Image')
plt.show()
Finally, we convert the rescaled arrays and their labels into shuffled, batched tf.data.Dataset objects for model training.
# Convert from ndarray to Tensor and build the batched dataset
# X_train_rescaled_final = tf.convert_to_tensor((X_train_rescaled))
y_train = tf.convert_to_tensor((y_train))
X_train = tf.data.Dataset.from_tensor_slices((X_train_rescaled, y_train))
X_train = X_train.shuffle(1000).batch(32, drop_remainder=True)
X_train
<BatchDataset element_spec=(TensorSpec(shape=(32, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 1), dtype=tf.uint8, name=None))>
y_test = tf.convert_to_tensor((y_test))
X_test_rescaled = (X_test / 127.5 - 1.).astype('float32')
X_test = tf.data.Dataset.from_tensor_slices((X_test_rescaled, y_test))
X_test = X_test.shuffle(1000).batch(32, drop_remainder=True)
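With `drop_remainder=True`, the final partial batch is discarded, which determines how many training steps make up one epoch. The arithmetic can be sketched in plain Python; `batches_per_epoch` is a hypothetical helper used only for illustration.

```python
def batches_per_epoch(num_samples, batch_size, drop_remainder=True):
    """Number of batches a dataset yields per pass over the data."""
    full_batches, leftover = divmod(num_samples, batch_size)
    if leftover and not drop_remainder:
        # A final, smaller batch is kept
        return full_batches + 1
    return full_batches

train_steps = batches_per_epoch(50_000, 32)
train_steps_keep_last = batches_per_epoch(50_000, 32, drop_remainder=False)
print(train_steps, train_steps_keep_last)
```

For the 50,000 training images and a batch size of 32, this gives 1562 full batches, with the remaining 16 images dropped on each pass.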
As the Generator and Discriminator are trained together to maintain an equilibrium in a zero-sum game, there is no single objective loss function that can evaluate the performance of the two models.
When measuring the performance of GAN models, there are two properties to evaluate: the quality of the generated images and their diversity.
A number of qualitative and quantitative techniques have been developed to evaluate GAN models on these two properties.
Sources:
https://machinelearningmastery.com/how-to-evaluate-generative-adversarial-networks/
https://towardsdatascience.com/on-the-evaluation-of-generative-adversarial-networks-b056ddcdfd3a
To evaluate our GAN models, we will be using the Fréchet Inception Distance (FID) and the Kullback-Leibler (KL) Divergence.
Fréchet Inception Distance (FID) evaluates the quality of generated images by calculating the distance between feature vectors calculated for real and generated images.
The FID score summarizes the similarity between the real and fake images in terms of statistics on computer vision features of the raw images, calculated by a feature extractor. The most common feature extractor is the Inception-v3 classifier, which is pre-trained on ImageNet. By excluding the output layer, we extract feature embeddings for the real and fake images. Each set of embeddings is modelled as a multivariate normal distribution, and the two distributions are compared using the Wasserstein-2 distance. A lower FID means the generated images are statistically closer to the real ones.
\begin{aligned}
FID &= \left\| \mu_r - \mu_g \right\|^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{\frac{1}{2}}\right)\\
\text{where}\\
\mu_r &\text{ is the feature-wise mean of the real images.} \\
\mu_g &\text{ is the feature-wise mean of the generated images.} \\
\Sigma_r &\text{ is the covariance matrix of the real images.} \\
\Sigma_g &\text{ is the covariance matrix of the generated images.} \\
\text{Tr} &\text{ denotes the trace of a matrix, which is the sum of all the diagonal elements.} \\
\end{aligned}
Note that FID has its downsides. It relies on a pre-trained Inception model, which may not capture all relevant features and can introduce bias when the generated images differ greatly from the domain the model was trained on. Moreover, it needs a large number of samples to be accurate, and it summarizes each distribution using only its mean and covariance.
Sources:
https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/
https://www.oreilly.com/library/view/generative-adversarial-networks/9781789136678/9bf2e543-8251-409e-a811-77e55d0dc021.xhtml
https://www.techtarget.com/searchenterpriseai/definition/Frechet-inception-distance-FID
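The FID formula can be checked on synthetic feature vectors before any Inception features are involved. The sketch below is an illustration rather than the notebook's implementation: it assumes NumPy, uses random 8-dimensional vectors in place of Inception embeddings, and computes the trace of the matrix square root from the eigenvalues of the covariance product instead of calling `scipy.linalg.sqrtm`. The name `frechet_distance` is hypothetical.

```python
import numpy as np

def frechet_distance(features_a, features_b):
    """Frechet distance between Gaussians fitted to two sets of feature vectors."""
    mu_a, mu_b = features_a.mean(axis=0), features_b.mean(axis=0)
    sigma_a = np.cov(features_a, rowvar=False)
    sigma_b = np.cov(features_b, rowvar=False)
    # Tr((Sigma_a Sigma_b)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of Sigma_a Sigma_b; tiny negative or imaginary parts are
    # numerical noise, so keep the real part and clip at zero.
    eigenvalues = np.linalg.eigvals(sigma_a @ sigma_b)
    trace_sqrt = np.sqrt(np.clip(eigenvalues.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(sigma_a) + np.trace(sigma_b) - 2.0 * trace_sqrt)

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))   # stand-in for real-image features
shifted = real + 0.5                # same covariance, every mean moved by 0.5

fid_same = frechet_distance(real, real.copy())
fid_shifted = frechet_distance(real, shifted)
print(fid_same, fid_shifted)
```

Identical feature sets give a distance of essentially zero. Shifting every feature by 0.5 leaves the covariance term at zero, so only the mean term remains: 8 features times 0.5 squared, which is 2.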
def calcFID(input_images, num_images=1000):
    # InceptionV3 without its classification head; global average pooling
    # yields one feature vector per image
    inceptionModel = InceptionV3(
        include_top=False,
        weights="imagenet",
        pooling='avg',
    )

    def scale_images(images, new_shape):
        # Resize every image to the spatial size InceptionV3 was trained on
        images_list = []
        for image in images:
            new_image = resize(image, new_shape, anti_aliasing=True)
            images_list.append(new_image)
        return np.array(images_list)

    def calculate_fid(model, images1, images2):
        # Feature activations for both image sets
        act1 = model.predict(images1)
        act2 = model.predict(images2)
        # Mean and covariance of each set of activations
        mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
        mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
        # Squared distance between the means
        ssdiff = np.sum((mu1 - mu2)**2.0)
        # Matrix square root of the covariance product
        covmean = sqrtm(sigma1.dot(sigma2))
        # sqrtm can return small imaginary parts due to numerical error
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

    # Sample real images and rescale them to [-1, 1]
    (real_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
    np.random.shuffle(real_images)
    real_images = real_images[:num_images]
    real_images = real_images.astype('float32')
    real_images = (real_images / 127.5 - 1)
    real_images = scale_images(real_images, (299, 299))
    # Generated images are used as passed in, so they should already be in [-1, 1]
    generated_images = input_images.astype('float32')
    generated_images = scale_images(generated_images, (299, 299))
    fid = calculate_fid(inceptionModel, real_images, generated_images)
    return fid
To help us further determine the quality of the model, we shall make use of the Kullback-Leibler Divergence metric (KL Divergence) to help us quantitatively evaluate the results of the model. KL Divergence is a statistical measure that quantifies how different one probability distribution is from another reference probability distribution. It is also known as relative entropy. KL Divergence is non-negative and asymmetric, meaning that the divergence of P from Q is not the same as the divergence of Q from P. It is often used in the field of machine learning to measure the difference between the predicted and true probability distributions of data, or to compare a model's distribution with the empirical distribution of the data. KL-divergence can be used within the generator loss function to encourage diverse outputs. This typically involves calculating the KL-divergence between the generated data distribution and a desired target distribution, like a uniform distribution over the data space.
The formula for KL Divergence is as follows:
\begin{align*}
&D_{KL}(P \parallel Q) = \sum_{i} P(i) \log \left( \frac{P(i)}{Q(i)} \right) \\
&\text{where}\\
&P \text{ is the true distribution} \\
&Q \text{ is the distribution to compare against} \\
&\sum_{i} \text{ is taken over all possible events}
\end{align*}
This equation sums, over all events, the probability under the true distribution P(i) multiplied by the logarithm of the ratio of P(i) to the probability under the comparison distribution Q(i). It measures the information lost when Q is used to approximate P. For our use case, we want to see a low KL Divergence score, as it means that the distribution of the generated data is very close to the distribution of the real data. This is the goal of a well-functioning GAN: to generate data that is indistinguishable from real data. In this notebook, the KL Divergence is computed between pixel-intensity histograms of a real batch and a generated batch, so it compares overall pixel statistics rather than image content. It also gives us a metric to directly compare models with, and provides feedback on how the model is performing during training.
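The properties described above (non-negativity, zero for identical distributions, and asymmetry) can be seen on a small discrete example. This is a plain-Python sketch of the formula, not the histogram-based version used in the model class; it assumes that Q(i) is positive wherever P(i) is positive, and the two distributions are made up for illustration.

```python
import math

def kl_divergence(p, q):
    """Discrete D_KL(P || Q) in nats; terms with P(i) = 0 contribute nothing."""
    return sum(p_i * math.log(p_i / q_i) for p_i, q_i in zip(p, q) if p_i > 0)

p = [0.5, 0.4, 0.1]          # "true" distribution (illustrative values)
q = [1 / 3, 1 / 3, 1 / 3]    # uniform comparison distribution

kl_pp = kl_divergence(p, p)  # identical distributions
kl_pq = kl_divergence(p, q)  # divergence of P from Q
kl_qp = kl_divergence(q, p)  # divergence of Q from P
print(kl_pp, kl_pq, kl_qp)
```

The first value is zero, while the other two are positive and differ from each other, which is why the order of the arguments matters when the metric is reported.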
GANs are an approach to generative modelling using deep learning methods, like CNNs.
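GANs train a generative model by framing the problem as a supervised learning problem with two sub-models: the Generator, which is trained to generate new samples, and the Discriminator, which tries to classify samples as either real (from the dataset) or fake (generated).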
Training Process
These two models are trained together in a zero-sum game until the Discriminator is fooled about half the time.
The Generator generates a batch of samples, which are provided to the Discriminator together with real samples from the dataset, to be classified as real (1) or fake (0).
While one model trains, the other model's weights remain constant; otherwise the Generator would be trying to hit a moving target and might never converge. Training proceeds in alternating periods, with each model taking turns to train for one or more epochs.
Loss Functions
Loss functions reflect the distance between the distribution of generated data and the distribution of real data.
Through backpropagation, the Discriminator's weights are updated from the discriminator loss to get better at discriminating, while the Generator's weights are updated from the generator loss based on the Discriminator classification, which is how well or not the generated samples fool the Discriminator.
The game is zero-sum in the sense that when the Discriminator successfully identifies the real and fake samples, it is rewarded (or needs no change to its parameters), whereas the Generator is penalized with large updates to its parameters, and vice versa.
Convergence
At a limit, the Discriminator cannot tell the difference between perfect replicas and the real images, hence predicts "unsure" (e.g. 50% for real and fake). If the GAN continues training with random feedback from the Discriminator, the model might collapse. For a GAN, convergence is often a fleeting, rather than stable, state.
Sources:
https://developers.google.com/machine-learning/gan/gan_structure
https://machinelearningmastery.com/how-to-develop-a-generative-adversarial-network-for-a-cifar-10-small-object-photographs-from-scratch/
https://machinelearningmastery.com/how-to-code-generative-adversarial-network-hacks/
To start tackling the task, we first design a baseline template class, which we can then build off of to create models which use different architectures. As we have already done our preprocessing, we can go straight onto modelling.
To start, we shall use the DCGAN architecture as a baseline, then move on to the cGAN, WGAN, and Hinge GAN architectures. From there, we can pick the best architecture for our task and tune its hyperparameters further to determine the final GAN model to generate images from.
We shall also use the Binary Cross-Entropy (BCE) loss function rather than alternatives such as Sparse Categorical Crossentropy, as it is designed for binary classification problems, in which the goal is to predict one of two possible outcomes. This fits our task: in a GAN, the discriminator's task is to classify inputs as real or fake, while the generator's job is to produce outputs that the discriminator classifies as real. The formula for Binary Cross-Entropy (BCE) is as follows:
\begin{align*}
&BCE = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right] \\
&\text{where}\\
&N \text{ is the number of observations,} \\
&y_i \text{ is the actual label of the } i^{\text{th}} \text{ observation,} \\
&\hat{y}_i \text{ is the predicted probability that the } i^{\text{th}} \text{ observation is of the positive class.}
\end{align*}
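To make the BCE formula concrete, the sketch below evaluates it by hand for two sets of discriminator outputs. It is a plain-Python illustration, not the Keras implementation; the clipping constant `epsilon` is an assumption added to avoid taking the logarithm of zero, and the example predictions are made up.

```python
import math

def binary_cross_entropy(y_true, y_pred, epsilon=1e-7):
    """Mean binary cross-entropy of predicted probabilities against 0/1 labels."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, epsilon), 1.0 - epsilon)  # keep log() finite
        total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(y_true)

labels = [1, 1, 0, 0]

# A discriminator that is mostly right and fairly confident
confident_loss = binary_cross_entropy(labels, [0.9, 0.8, 0.1, 0.2])
# A discriminator that answers 0.5 for everything
unsure_loss = binary_cross_entropy(labels, [0.5, 0.5, 0.5, 0.5])
print(confident_loss, unsure_loss)
```

A discriminator that always answers 0.5 has a loss of ln 2 (about 0.693) whatever the labels are, which is a useful reference point when reading the training logs.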
Furthermore, we shall start with the Adam optimizer. Adam (ADAptive Moment Estimation) is an extension of Stochastic Gradient Descent that combines ideas from two other optimization algorithms, namely Momentum and RMSProp. It is a good starting choice because its learning rates are adaptive: each parameter's effective step size is scaled by running estimates of the mean and the squared magnitude of its gradients, making it suitable for problems with sparse gradients or noisy data. It is also computationally efficient with relatively low memory requirements, making it suitable for problems with large datasets or many parameters.
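The way Adam combines a Momentum-style average with an RMSProp-style scaling can be sketched for a single scalar parameter. This follows the textbook update rule and is only an approximation of what the Keras optimizer does internally; the defaults for `lr` and `beta_1` mirror the values used in the notebook's discriminator, while `beta_2` and `eps` are assumed typical values.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.0002, beta_1=0.5, beta_2=0.999, eps=1e-7):
    """One textbook Adam update for a scalar parameter at step t (starting from 1)."""
    m = beta_1 * m + (1 - beta_1) * grad        # Momentum: running mean of gradients
    v = beta_2 * v + (1 - beta_2) * grad ** 2   # RMSProp: running mean of squared gradients
    m_hat = m / (1 - beta_1 ** t)               # bias correction for the zero initialisation
    v_hat = v / (1 - beta_2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# First step from the same starting point with a large and a small gradient
after_large, _, _ = adam_step(1.0, 10.0, 0.0, 0.0, t=1)
after_small, _, _ = adam_step(1.0, 0.01, 0.0, 0.0, t=1)
print(1.0 - after_large, 1.0 - after_small)
```

Both first steps move the parameter by roughly the learning rate, even though the gradients differ by a factor of 1000. This normalisation by gradient magnitude is what makes the step size adaptive per parameter.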
We create the template class here, and implement a custom callback so that we can track the model's performance more easily.
class CustomCallback(Callback):
    """Record per-epoch metrics, save a preview plot and checkpoint the weights."""

    def __init__(self, d_losses, g_losses, kl_div, model, filepath):
        super(CustomCallback, self).__init__()
        self.d_losses = d_losses
        self.g_losses = g_losses
        self.kl_div = kl_div
        self.model = model
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}  # guard against logs being None
        gan_model = self.model
        generator = gan_model.generator
        # Metrics reported by train_step for this epoch
        d_loss = logs.get('d_loss')
        g_loss = logs.get('g_loss')
        kl_div = logs.get("kl_divergence")
        # Histories start as lists; convert to arrays before appending
        self.d_losses = np.array(list(self.d_losses))
        self.g_losses = np.array(list(self.g_losses))
        self.kl_div = np.array(list(self.kl_div))
        self.d_losses = np.append(self.d_losses, d_loss)
        self.g_losses = np.append(self.g_losses, g_loss)
        self.kl_div = np.append(self.kl_div, kl_div)
        # generate_fake_samples is a static method whose first parameter is unused,
        # so the model is passed explicitly
        generated_images, generated_labels = gan_model.generate_fake_samples(self.model, generator = generator)
        self.model.save_plot(generated_images, epoch, self.d_losses, self.g_losses, self.kl_div, self.filepath)
        self.model.save_weights(f"{self.filepath}weights/weights_{epoch}.h5")
class GAN_template(Model):
    def __init__(self, latent_dim):
        super().__init__()
        self.discriminator = self.define_discriminator()
        self.generator = self.define_generator(latent_dim)
        self.latent_dim = latent_dim
        self.d_loss_tracker = Mean(name="d_loss")
        self.g_loss_tracker = Mean(name="g_loss")
        self.kl_divergence_tracker = Mean(name = "kl_divergence")
        self.g_loss_list = []
        self.d_loss_list = []
        self.kl_div_list = []

    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(15, 10))
        gs = fig.add_gridspec(4, 6, height_ratios=[1, 1, 1, 1.2], width_ratios=[1, 1, 1, 1, 1, 1], hspace=0.4, wspace=0.4)
        # Map generator output from [-1, 1] to [0, 1] for display
        examples = (examples + 1) / 2.0
        # Top three rows: 18 generated samples
        for i in range(3 * 6):
            ax = fig.add_subplot(gs[i // 6, i % 6])
            ax.axis('off')
            ax.imshow(examples[i])
        # Bottom row: metric histories
        ax_loss = fig.add_subplot(gs[3, 0:2])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        ax_g_loss = fig.add_subplot(gs[3, 2:4])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")
        ax_kl_div = fig.add_subplot(gs[3, 4:6])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")
        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.92)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    def kl_divergence(self, real_data, generated_data):
        # Compare pixel-intensity histograms of the real and generated batches.
        # Note: the images are scaled to [-1, 1], so with a [0, 1] range all
        # values below 0 are counted in the first bin.
        real_data_hist = tf.histogram_fixed_width(real_data, [0, 1], nbins=256)
        generated_data_hist = tf.histogram_fixed_width(generated_data, [0, 1], nbins=256)
        # Normalise the counts into probabilities
        real_data_prob = real_data_hist / tf.reduce_sum(real_data_hist)
        generated_data_prob = generated_data_hist / tf.reduce_sum(generated_data_hist)
        # Small constant to avoid division by zero and log(0)
        epsilon = 1e-10
        real_data_prob += epsilon
        generated_data_prob += epsilon
        kl_div = tf.reduce_sum(real_data_prob * tf.math.log(real_data_prob / generated_data_prob+epsilon))
        return kl_div

    @staticmethod
    def generate_fake_samples(self, generator, n_samples=18, latent_dim=100):
        # Static method: the first parameter is unused and must be passed explicitly
        x_input = np.random.randn(latent_dim * n_samples)
        x_input = x_input.reshape(n_samples, latent_dim)
        X = generator.predict(x_input, verbose=0)
        y = np.zeros((n_samples, 1))
        return X, y

    def define_discriminator(self, in_shape=(32,32,3)):
        # Overridden by each architecture subclass
        pass

    def define_generator(self, latent_dim):
        # Overridden by each architecture subclass
        pass

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Keras passes one batch per call; the dataset yields (images, labels)
        # tuples, and only the images are needed here
        real_images = data[0] if isinstance(data, tuple) else data
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        generated_images = self.generator(random_latent_vectors)
        combined_images = tf.concat([generated_images, tf.cast(real_images, tf.float32)], axis=0)
        # Label convention in this step: generated images are 1, real images are 0
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
        # Add random noise to the labels
        labels += 0.05 * tf.random.uniform(tf.shape(labels))
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )
        # Train the generator: label its images as real (0 under the convention
        # above) and update only the generator's weights
        misleading_labels = tf.zeros((batch_size, 1))
        with tf.GradientTape() as tape:
            generated_images = self.generator(random_latent_vectors)
            predictions = self.discriminator(generated_images)
            g_loss = self.loss_fn(misleading_labels, predictions)
            kl_loss = self.kl_divergence(real_images, generated_images)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        # Update metrics and return their value.
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
        }
DCGAN uses convolutional layers in the discriminator and convolutional-transpose layers in the generator. It was proposed by Radford et al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.
The Discriminator consists of convolution layers with a stride of 2x2 to downsample the input image, batch normalization layers, and LeakyReLU activations. Pooling layers are replaced with strided convolutions, as a strided convolution reduces the spatial dimensions by jumping multiple pixels between kernel applications instead of sliding the kernel one pixel at a time. In our implementation, the discriminator takes a 32x32x3 input image (the CIFAR10 image size) and does not use batch normalization. The discriminator is trained to minimize the binary cross-entropy loss function, which is suitable for binary classification.
The Generator consists of convolutional-transpose layers, batch normalization layers, and ReLU activations. In our implementation, the generator uses LeakyReLU activations without batch normalization, and its output is a 32x32x3 RGB image.
Other key features of DCGANs include the use of ReLU activation functions in the generator (except for the output layer, which uses Tanh) and the elimination of fully connected layers in favour of connecting directly to the convolutional layers where possible.
DCGANs have been used in various applications like photo editing, art creation, image super-resolution, and more. They are particularly noted for their ability to generate high-quality images and learn hierarchical representations of objects in images.
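The effect of strided and transposed convolutions on feature-map size can be worked out with simple arithmetic. The sketch below assumes `'same'` padding, where a convolution with stride s maps a side of length n to ceil(n / s) and a transposed convolution maps it to n * s; the helper names are illustrative, and the stride sequences are those of a four-convolution discriminator and a three-upsampling generator for 32x32 images.

```python
import math

def conv_same_out(size, stride):
    """Output side length of a 'same'-padded convolution."""
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    """Output side length of a 'same'-padded transposed convolution."""
    return size * stride

# Discriminator: one stride-1 convolution, then three stride-2 downsampling steps
disc_sizes = [32]
for stride in (1, 2, 2, 2):
    disc_sizes.append(conv_same_out(disc_sizes[-1], stride))

# Generator: start from 4x4 feature maps, then three stride-2 upsampling steps
gen_sizes = [4]
for stride in (2, 2, 2):
    gen_sizes.append(conv_transpose_same_out(gen_sizes[-1], stride))

# Features entering the dense classifier if the last convolution has 128 filters
flattened_features = disc_sizes[-1] ** 2 * 128
print(disc_sizes, gen_sizes, flattened_features)
```

The two paths mirror each other: the discriminator halves 32x32 down to 4x4, and the generator doubles 4x4 back up to 32x32.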
Insights:
class DCGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)

    def define_discriminator(self, in_shape=(32,32,3)):
        model = Sequential()
        model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.0015)))
        model.add(LeakyReLU(alpha=0.2))
        # Classifier
        model.add(Flatten())
        model.add(Dropout(0.5))
        model.add(Dense(1, activation='sigmoid'))
        # Compile Model
        model.compile(loss='binary_crossentropy', optimizer = Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
        # Nodes to represent a low-resolution version of the output image
        n_nodes = 256 * 4 * 4
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((4, 4, 256))) # Activations from these nodes can then be reshaped into something image-like, e.g. 256 different 4 x 4 feature maps
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')) # Combines UpSampling & Conv2D layers, stride of 2x2 quadruples area of the input feature maps
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(3, (3,3), activation='tanh', padding='same')) # Three filters for three color channels
        return model
dcgan = DCGAN(latent_dim=100)
dcgan.compile(
    d_optimizer=Adam(learning_rate=0.0003),
    g_optimizer=Adam(learning_rate=0.0003),
    # Note: the discriminator ends in a sigmoid, so it outputs probabilities rather
    # than logits; from_logits=False would match that. The run logged below used
    # from_logits=True.
    loss_fn=BinaryCrossentropy(from_logits=True),
)
dcgan_callback = CustomCallback(
    d_losses=dcgan.d_loss_list,
    g_losses=dcgan.g_loss_list,
    kl_div=dcgan.kl_div_list,
    model=dcgan,
    filepath="output/models/dcgan/",
)
dcgan.fit(X_train, epochs=200, callbacks=[dcgan_callback])
Epoch 1/200 1562/1562 [==============================] - 37s 23ms/step - d_loss: 0.4620 - g_loss: 2.9281 - kl_divergence: 0.9407
Epoch 2/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5346 - g_loss: 1.7597 - kl_divergence: 0.5963
Epoch 3/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5822 - g_loss: 1.4851 - kl_divergence: 0.4541
Epoch 4/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5950 - g_loss: 1.5685 - kl_divergence: 0.5027
Epoch 5/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5956 - g_loss: 1.4207 - kl_divergence: 0.4210
Epoch 6/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5994 - g_loss: 1.3362 - kl_divergence: 0.4004
Epoch 7/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6078 - g_loss: 1.5065 - kl_divergence: 0.4278
Epoch 8/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5852 - g_loss: 1.3864 - kl_divergence: 0.4132
Epoch 9/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6130 - g_loss: 1.2438 - kl_divergence: 0.3762
Epoch 10/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.6243 - g_loss: 1.2382 - kl_divergence: 0.3712
Epoch 11/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6193 - g_loss: 1.1890 - kl_divergence: 0.3827
Epoch 12/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6070 - g_loss: 1.3629 - kl_divergence: 0.3747
Epoch 13/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6219 - g_loss: 1.2310 - kl_divergence: 0.3725
Epoch 14/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.6071 - g_loss: 1.2327 - kl_divergence: 0.3730
Epoch 15/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6028 - g_loss: 1.3161 - kl_divergence: 0.3580
Epoch 16/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6075 - g_loss: 1.2134 - kl_divergence: 0.3661
Epoch 17/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6232 - g_loss: 1.1583 - kl_divergence: 0.3574
Epoch 18/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6021 - g_loss: 1.3342 - kl_divergence: 0.3630
Epoch 19/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6132 - g_loss: 1.2154 - kl_divergence: 0.3565
Epoch 20/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6152 - g_loss: 1.2352 - kl_divergence: 0.3532
Epoch 21/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6013 - g_loss: 1.3076 - kl_divergence: 0.3552
Epoch 22/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6252 - g_loss: 1.3256 - kl_divergence: 0.3809
Epoch 23/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5911 - g_loss: 1.2972 - kl_divergence: 0.3654
Epoch 24/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5913 - g_loss: 1.3727 - kl_divergence: 0.3902
Epoch 25/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5696 - g_loss: 1.4636 - kl_divergence: 0.3603
Epoch 26/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5680 - g_loss: 1.4073 - kl_divergence: 0.3575
Epoch 27/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5791 - g_loss: 1.3054 - kl_divergence: 0.3801
Epoch 28/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5646 - g_loss: 1.3885 - kl_divergence: 0.3635
Epoch 29/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5773 - g_loss: 1.3255 - kl_divergence: 0.3781
Epoch 30/200 1562/1562 [==============================] - 37s 23ms/step - d_loss: 0.5655 - g_loss: 1.3874 - kl_divergence: 0.3788
Epoch 31/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5847 - g_loss: 1.3226 - kl_divergence: 0.3655
Epoch 32/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5868 - g_loss: 1.3944 - kl_divergence: 0.3694
Epoch 33/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5600 - g_loss: 1.4091 - kl_divergence: 0.3805
Epoch 34/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5813 - g_loss: 1.2923 - kl_divergence: 0.3769
Epoch 35/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5915 - g_loss: 1.2569 - kl_divergence: 0.3660
Epoch 36/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5841 - g_loss: 1.2796 - kl_divergence: 0.3703
Epoch 37/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5931 - g_loss: 1.2333 - kl_divergence: 0.3721
Epoch 38/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5917 - g_loss: 1.2376 - kl_divergence: 0.3758
Epoch 39/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.6021 - g_loss: 1.2194 - kl_divergence: 0.3736
Epoch 40/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5958 - g_loss: 1.2724 - kl_divergence: 0.3665
Epoch 41/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5923 - g_loss: 1.2394 - kl_divergence: 0.3566
Epoch 42/200 1562/1562 [==============================] - 37s 23ms/step - d_loss: 0.6100 - g_loss: 1.1156 - kl_divergence: 0.3535
Epoch 43/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6049 - g_loss: 1.1759 - kl_divergence: 0.3605
Epoch 44/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6048 - g_loss: 1.2029 - kl_divergence: 0.3526
Epoch 45/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5992 - g_loss: 1.2041 - kl_divergence: 0.3490
Epoch 46/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.6130 - g_loss: 1.1736 - kl_divergence: 0.3476
Epoch 47/200 1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.6043 - g_loss: 1.1408 - kl_divergence: 0.3466
Epoch 48/200 1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5928 - g_loss: 1.2183 - kl_divergence: 0.3482
Epoch 49/200 1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.6014 - g_loss: 1.1852 - kl_divergence: 0.3518
Epoch 50/200 1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5988 - g_loss: 1.1598 - kl_divergence: 0.3476
Epoch 51/200 1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5857 - g_loss: 1.2534 - kl_divergence: 0.3535
Epoch 52/200 1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5901 - g_loss: 1.2218 - kl_divergence: 0.3512
Epoch 53/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5925 - g_loss: 1.2221 - kl_divergence: 0.3480
Epoch 54/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5746 - g_loss: 1.3319 - kl_divergence: 0.3488
Epoch 55/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5755 - g_loss: 1.2579 - kl_divergence: 0.3497
Epoch 56/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5878 - g_loss: 1.2517 - kl_divergence: 0.3483
Epoch 57/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5866 - g_loss: 1.1762 - kl_divergence: 0.3628
Epoch 58/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5816 - g_loss: 1.2031 - kl_divergence: 0.3559
Epoch 59/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5997 - g_loss: 1.1481 - kl_divergence: 0.3522
Epoch 60/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5970 - g_loss: 1.1498 - kl_divergence: 0.3567
Epoch 61/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6024 - g_loss: 1.1170 - kl_divergence: 0.3555
Epoch 62/200 1562/1562 [==============================] - 35s 23ms/step
- d_loss: 0.5952 - g_loss: 1.1467 - kl_divergence: 0.3547 Epoch 63/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.6071 - g_loss: 1.1054 - kl_divergence: 0.3600 Epoch 64/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6104 - g_loss: 1.1025 - kl_divergence: 0.3544 Epoch 65/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5994 - g_loss: 1.1217 - kl_divergence: 0.3597 Epoch 66/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6006 - g_loss: 1.1255 - kl_divergence: 0.3495 Epoch 67/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6018 - g_loss: 1.1059 - kl_divergence: 0.3489 Epoch 68/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6009 - g_loss: 1.1285 - kl_divergence: 0.3471 Epoch 69/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6035 - g_loss: 1.1244 - kl_divergence: 0.3538 Epoch 70/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5891 - g_loss: 1.1846 - kl_divergence: 0.3511 Epoch 71/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5823 - g_loss: 1.1779 - kl_divergence: 0.3523 Epoch 72/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5877 - g_loss: 1.1787 - kl_divergence: 0.3544 Epoch 73/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5797 - g_loss: 1.2198 - kl_divergence: 0.3522 Epoch 74/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5797 - g_loss: 1.1848 - kl_divergence: 0.3537 Epoch 75/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5837 - g_loss: 1.1969 - kl_divergence: 0.3542 Epoch 76/200 1562/1562 [==============================] - 37s 24ms/step - d_loss: 0.5762 - g_loss: 1.2353 - kl_divergence: 0.3533 Epoch 77/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5787 - g_loss: 1.2061 - 
kl_divergence: 0.3569 Epoch 78/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5753 - g_loss: 1.2357 - kl_divergence: 0.3573 Epoch 79/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5752 - g_loss: 1.2425 - kl_divergence: 0.3561 Epoch 80/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5699 - g_loss: 1.2389 - kl_divergence: 0.3590 Epoch 81/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5881 - g_loss: 1.2102 - kl_divergence: 0.3481 Epoch 82/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5780 - g_loss: 1.2019 - kl_divergence: 0.3618 Epoch 83/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5682 - g_loss: 1.2382 - kl_divergence: 0.3646 Epoch 84/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5718 - g_loss: 1.2492 - kl_divergence: 0.3620 Epoch 85/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5714 - g_loss: 1.2350 - kl_divergence: 0.3624 Epoch 86/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5677 - g_loss: 1.2485 - kl_divergence: 0.3618 Epoch 87/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5691 - g_loss: 1.2545 - kl_divergence: 0.3598 Epoch 88/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5729 - g_loss: 1.2548 - kl_divergence: 0.3557 Epoch 89/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5694 - g_loss: 1.2296 - kl_divergence: 0.3610 Epoch 90/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5667 - g_loss: 1.2560 - kl_divergence: 0.3606 Epoch 91/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5607 - g_loss: 1.2710 - kl_divergence: 0.3625 Epoch 92/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5805 - g_loss: 1.2325 - kl_divergence: 0.3580 Epoch 93/200 1562/1562 
[==============================] - 36s 23ms/step - d_loss: 0.5596 - g_loss: 1.2847 - kl_divergence: 0.3604 Epoch 94/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5592 - g_loss: 1.2449 - kl_divergence: 0.3561 Epoch 95/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5701 - g_loss: 1.2569 - kl_divergence: 0.3565 Epoch 96/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5673 - g_loss: 1.2507 - kl_divergence: 0.3584 Epoch 97/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5537 - g_loss: 1.2816 - kl_divergence: 0.3558 Epoch 98/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5629 - g_loss: 1.2954 - kl_divergence: 0.3567 Epoch 99/200 1562/1562 [==============================] - 37s 24ms/step - d_loss: 0.5559 - g_loss: 1.3044 - kl_divergence: 0.3575 Epoch 100/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5604 - g_loss: 1.2645 - kl_divergence: 0.3512 Epoch 101/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5596 - g_loss: 1.2886 - kl_divergence: 0.3518 Epoch 102/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5627 - g_loss: 1.2622 - kl_divergence: 0.3646 Epoch 103/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5586 - g_loss: 1.2667 - kl_divergence: 0.3632 Epoch 104/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5575 - g_loss: 1.2881 - kl_divergence: 0.3569 Epoch 105/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5596 - g_loss: 1.2592 - kl_divergence: 0.3560 Epoch 106/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5533 - g_loss: 1.3084 - kl_divergence: 0.3587 Epoch 107/200 1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5541 - g_loss: 1.2693 - kl_divergence: 0.3585 Epoch 108/200 1562/1562 [==============================] - 36s 
23ms/step - d_loss: 0.5517 - g_loss: 1.3361 - kl_divergence: 0.3595 Epoch 109/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5534 - g_loss: 1.2923 - kl_divergence: 0.3525 Epoch 110/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5522 - g_loss: 1.3019 - kl_divergence: 0.3573 Epoch 111/200 1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5476 - g_loss: 1.3247 - kl_divergence: 0.3599 Epoch 112/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5421 - g_loss: 1.3081 - kl_divergence: 0.3577 Epoch 113/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5627 - g_loss: 1.2644 - kl_divergence: 0.3509 Epoch 114/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5542 - g_loss: 1.2908 - kl_divergence: 0.3541 Epoch 115/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5572 - g_loss: 1.2493 - kl_divergence: 0.3529 Epoch 116/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5610 - g_loss: 1.3004 - kl_divergence: 0.3507 Epoch 117/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5583 - g_loss: 1.2566 - kl_divergence: 0.3528 Epoch 118/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5556 - g_loss: 1.2753 - kl_divergence: 0.3613 Epoch 119/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5610 - g_loss: 1.2382 - kl_divergence: 0.3560 Epoch 120/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5558 - g_loss: 1.2500 - kl_divergence: 0.3642 Epoch 121/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5601 - g_loss: 1.2719 - kl_divergence: 0.3568 Epoch 122/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5512 - g_loss: 1.3111 - kl_divergence: 0.3615 Epoch 123/200 1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5610 - 
g_loss: 1.2511 - kl_divergence: 0.3686 Epoch 124/200 1562/1562 [==============================] - 34s 22ms/step - d_loss: 0.5603 - g_loss: 1.2874 - kl_divergence: 0.3576 Epoch 125/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5575 - g_loss: 1.2420 - kl_divergence: 0.3533 Epoch 126/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5498 - g_loss: 1.2881 - kl_divergence: 0.3556 Epoch 127/200 1562/1562 [==============================] - 38s 24ms/step - d_loss: 0.5613 - g_loss: 1.2682 - kl_divergence: 0.3564 Epoch 128/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5518 - g_loss: 1.2522 - kl_divergence: 0.3571 Epoch 129/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5522 - g_loss: 1.2916 - kl_divergence: 0.3568 Epoch 130/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5506 - g_loss: 1.2723 - kl_divergence: 0.3549 Epoch 131/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5537 - g_loss: 1.2758 - kl_divergence: 0.3550 Epoch 132/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5510 - g_loss: 1.2763 - kl_divergence: 0.3569 Epoch 133/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5469 - g_loss: 1.3028 - kl_divergence: 0.3553 Epoch 134/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5610 - g_loss: 1.2862 - kl_divergence: 0.3518 Epoch 135/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5464 - g_loss: 1.2623 - kl_divergence: 0.3509 Epoch 136/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5381 - g_loss: 1.3185 - kl_divergence: 0.3563 Epoch 137/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5447 - g_loss: 1.3430 - kl_divergence: 0.3515 Epoch 138/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5464 - g_loss: 1.2760 - kl_divergence: 
0.3592 Epoch 139/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5436 - g_loss: 1.3322 - kl_divergence: 0.3521 Epoch 140/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5414 - g_loss: 1.2778 - kl_divergence: 0.3591 Epoch 141/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5434 - g_loss: 1.3041 - kl_divergence: 0.3534 Epoch 142/200 1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5407 - g_loss: 1.3358 - kl_divergence: 0.3551 Epoch 143/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5375 - g_loss: 1.3016 - kl_divergence: 0.3617 Epoch 144/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5436 - g_loss: 1.3120 - kl_divergence: 0.3521 Epoch 145/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5414 - g_loss: 1.2944 - kl_divergence: 0.3550 Epoch 146/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5414 - g_loss: 1.3288 - kl_divergence: 0.3521 Epoch 147/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5385 - g_loss: 1.3578 - kl_divergence: 0.3557 Epoch 148/200 1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5399 - g_loss: 1.2914 - kl_divergence: 0.3550 Epoch 149/200 1562/1562 [==============================] - 34s 22ms/step - d_loss: 0.5329 - g_loss: 1.3307 - kl_divergence: 0.3606 Epoch 150/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5351 - g_loss: 1.3280 - kl_divergence: 0.3570 Epoch 151/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5405 - g_loss: 1.3040 - kl_divergence: 0.3544 Epoch 152/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5433 - g_loss: 1.2975 - kl_divergence: 0.3530 Epoch 153/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5412 - g_loss: 1.3091 - kl_divergence: 0.3564 Epoch 154/200 1562/1562 
[==============================] - 33s 21ms/step - d_loss: 0.5423 - g_loss: 1.3010 - kl_divergence: 0.3527 Epoch 155/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5413 - g_loss: 1.3058 - kl_divergence: 0.3554 Epoch 156/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5329 - g_loss: 1.3560 - kl_divergence: 0.3547 Epoch 157/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5359 - g_loss: 1.3165 - kl_divergence: 0.3545 Epoch 158/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5389 - g_loss: 1.3312 - kl_divergence: 0.3525 Epoch 159/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5350 - g_loss: 1.3174 - kl_divergence: 0.3555 Epoch 160/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5302 - g_loss: 1.3220 - kl_divergence: 0.3563 Epoch 161/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5321 - g_loss: 1.3518 - kl_divergence: 0.3523 Epoch 162/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5321 - g_loss: 1.3550 - kl_divergence: 0.3504 Epoch 163/200 1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5275 - g_loss: 1.3714 - kl_divergence: 0.3541 Epoch 164/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5200 - g_loss: 1.3458 - kl_divergence: 0.3566 Epoch 165/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5189 - g_loss: 1.3846 - kl_divergence: 0.3608 Epoch 166/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5202 - g_loss: 1.3991 - kl_divergence: 0.3591 Epoch 167/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5233 - g_loss: 1.3678 - kl_divergence: 0.3558 Epoch 168/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5234 - g_loss: 1.3949 - kl_divergence: 0.3547 Epoch 169/200 1562/1562 [==============================] - 
33s 21ms/step - d_loss: 0.5186 - g_loss: 1.3869 - kl_divergence: 0.3568 Epoch 170/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5158 - g_loss: 1.3899 - kl_divergence: 0.3619 Epoch 171/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5196 - g_loss: 1.4413 - kl_divergence: 0.3572 Epoch 172/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5201 - g_loss: 1.3844 - kl_divergence: 0.3572 Epoch 173/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5183 - g_loss: 1.3875 - kl_divergence: 0.3585 Epoch 174/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5119 - g_loss: 1.4424 - kl_divergence: 0.3580 Epoch 175/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5131 - g_loss: 1.4047 - kl_divergence: 0.3569 Epoch 176/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5080 - g_loss: 1.4314 - kl_divergence: 0.3531 Epoch 177/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5145 - g_loss: 1.4126 - kl_divergence: 0.3548 Epoch 178/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5111 - g_loss: 1.4162 - kl_divergence: 0.3591 Epoch 179/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5136 - g_loss: 1.4060 - kl_divergence: 0.3570 Epoch 180/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5082 - g_loss: 1.4531 - kl_divergence: 0.3605 Epoch 181/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5110 - g_loss: 1.4555 - kl_divergence: 0.3600 Epoch 182/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5081 - g_loss: 1.3883 - kl_divergence: 0.3598 Epoch 183/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5035 - g_loss: 1.4574 - kl_divergence: 0.3585 Epoch 184/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5173 - 
g_loss: 1.4260 - kl_divergence: 0.3591 Epoch 185/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5080 - g_loss: 1.4357 - kl_divergence: 0.3574 Epoch 186/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5031 - g_loss: 1.4888 - kl_divergence: 0.3591 Epoch 187/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5039 - g_loss: 1.4264 - kl_divergence: 0.3632 Epoch 188/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5063 - g_loss: 1.4129 - kl_divergence: 0.3643 Epoch 189/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5042 - g_loss: 1.4980 - kl_divergence: 0.3597 Epoch 190/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.4995 - g_loss: 1.4302 - kl_divergence: 0.3591 Epoch 191/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5021 - g_loss: 1.4641 - kl_divergence: 0.3610 Epoch 192/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5033 - g_loss: 1.5035 - kl_divergence: 0.3614 Epoch 193/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.4998 - g_loss: 1.4920 - kl_divergence: 0.3562 Epoch 194/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5021 - g_loss: 1.4419 - kl_divergence: 0.3536 Epoch 195/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5064 - g_loss: 1.5291 - kl_divergence: 0.3603 Epoch 196/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5051 - g_loss: 1.4093 - kl_divergence: 0.3593 Epoch 197/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.4958 - g_loss: 1.4799 - kl_divergence: 0.3589 Epoch 198/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5000 - g_loss: 1.4938 - kl_divergence: 0.3601 Epoch 199/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5001 - g_loss: 1.4814 - kl_divergence: 
0.3581 Epoch 200/200 1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5042 - g_loss: 1.5149 - kl_divergence: 0.3614
class DCGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)

    def define_discriminator(self, in_shape=(32,32,3)):
        model = Sequential()
        model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.0015)))
        model.add(LeakyReLU(alpha=0.2))
        # Classifier
        model.add(Flatten())
        model.add(Dropout(0.5))
        model.add(Dense(1, activation='sigmoid'))
        # Compile Model
        model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
        # Nodes to represent a low-resolution version of the output image
        n_nodes = 256 * 4 * 4
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        # Activations from these nodes can then be reshaped into something image-like, e.g. 256 different 4 x 4 feature maps
        model.add(Reshape((4, 4, 256)))
        # Upsample: combines UpSampling & Conv2D layers, stride of 2x2 quadruples area of the input feature maps
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Three filters for three color channels
        model.add(Conv2D(3, (3,3), activation='tanh', padding='same'))
        return model
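As a quick sanity check on the layer stack above, the spatial sizes can be traced by hand. The sketch below is only an illustration and does not build any Keras model: it assumes the usual behaviour of `padding='same'`, where a strided convolution produces `ceil(size / stride)` and a transposed convolution produces `size * stride`. The helper names `conv_same` and `conv_transpose_same` are hypothetical and not part of the notebook.

```python
import math

def conv_same(size, stride):
    """Spatial size after a Conv2D layer with padding='same'."""
    return math.ceil(size / stride)

def conv_transpose_same(size, stride):
    """Spatial size after a Conv2DTranspose layer with padding='same'."""
    return size * stride

# Generator: Dense output reshaped to 4x4, then three stride-2 upsampling layers.
gen_sizes = [4]
for _ in range(3):
    gen_sizes.append(conv_transpose_same(gen_sizes[-1], 2))

# Discriminator: one stride-1 convolution, then three stride-2 downsampling layers.
disc_sizes = [32]
for stride in (1, 2, 2, 2):
    disc_sizes.append(conv_same(disc_sizes[-1], stride))

# The last discriminator convolution has 128 filters, so Flatten sees 4 * 4 * 128 values.
flattened_features = disc_sizes[-1] ** 2 * 128

print(gen_sizes)           # [4, 8, 16, 32]
print(disc_sizes)          # [32, 32, 16, 8, 4]
print(flattened_features)  # 2048
```

The generator therefore ends at 32 x 32, matching the CIFAR-10 image size, and the discriminator mirrors it back down to 4 x 4 before the classifier head.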
input_folder = './output/models/dcgan/generated/' # Replace with your frames directory
output_file = 'output_video.mp4' # Replace with your desired output file path
gnnf.create_video_from_frames(input_folder, output_file)
From the above, we can see that the images generated by DCGAN are somewhat coherent, with makeshift objects present in the images. The images also show some detail and somewhat resemble the images in the CIFAR-10 dataset.
However, we do not know which class of image the model is trying to generate, because the DCGAN receives no class information. Hence, we shall try to solve this problem with the next model architecture.
Conditional Generative Adversarial Networks (cGANs) represent an advanced evolution in the realm of Generative Adversarial Networks (GANs), specifically designed for generating data samples under defined conditions. The foundational work on cGANs is attributed to Mirza and Osindero in their seminal paper "Conditional Generative Adversarial Nets".
In the architecture of cGANs, the generator and discriminator are both conditioned on additional information, such as labels or tags, which guide the data generation process. This conditional approach allows for the generation of targeted data samples, enhancing the versatility and effectiveness of the network.
A key architectural feature of cGANs is the use of conditional information, such as class labels, to steer the data generation process, for example towards generating images of a particular class. The same information is also given to the discriminator to help it assess whether the generated data matches the given condition. In terms of activation functions, cGANs often employ the same choices as DCGANs.
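To make the conditioning idea concrete, here is a minimal NumPy sketch of one common way to feed a label to a generator: look up an embedding vector for each label and concatenate it with the latent vector. It is an illustration only. The embedding table here is random and hypothetical, whereas in a real model it would be a learned layer, and the batch size and labels are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, latent_dim, batch_size = 10, 100, 4

# Hypothetical embedding table: one vector of length latent_dim per class.
# In a trained cGAN these values would be learned, not random.
embedding_table = rng.normal(size=(num_classes, latent_dim))

labels = np.array([0, 3, 3, 9])                      # one class label per sample
latent = rng.normal(size=(batch_size, latent_dim))   # random noise vectors

# Embedding lookup: each label selects its row of the table, shape (4, 100).
label_vectors = embedding_table[labels]

# Conditioned generator input: noise and label embedding side by side, shape (4, 200).
generator_input = np.concatenate([latent, label_vectors], axis=1)

print(generator_input.shape)  # (4, 200)
```

Two samples with the same label share the same label half of the input and differ only in the noise half, which is what lets the noise control variation within a class.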
num_classes = 10

class cGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)
        self.num_classes = num_classes

    def define_discriminator(self, in_shape=(32,32,3)):
        # Image input
        image_input = Input(shape=in_shape)
        # Label input, embedded and reshaped to an image-sized tensor
        label_input = Input(shape=(1,))
        label_embedding = Embedding(num_classes, np.prod(in_shape))(label_input)
        label_embedding = Dense(np.prod(in_shape))(label_embedding)
        label_embedding = Reshape(in_shape)(label_embedding)
        # Stack the label tensor onto the image along the channel axis
        concatenated = Concatenate()([image_input, label_embedding])
        x = Conv2D(64, (3,3), padding='same')(concatenated)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(128, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.001, l2=0.001))(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Flatten()(x)
        x = Dropout(0.4)(x)
        output = Dense(1, activation='sigmoid')(x)
        model = Model(inputs=[image_input, label_input], outputs=output)
        # `lr` is a deprecated alias; use `learning_rate`
        model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        # Model for processing the labels
        label_input = tf.keras.Input(shape=(1,), dtype='int32')
        label_embedding = Embedding(num_classes, latent_dim)(label_input)
        label_embedding = Flatten()(label_embedding)
        latent_input = tf.keras.Input(shape=(latent_dim,))
        merged_input = Concatenate()([latent_input, label_embedding])
        # Sequential model for the generator
        generator = Sequential([
            Dense(8 * 8 * 256, input_shape=(latent_dim * 2,)),
            LeakyReLU(alpha=0.2),
            Reshape((8, 8, 256)),
            Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(3, (3, 3), activation='tanh', padding='same')
        ])
        # Pass the merged input through the generator model
        generated_image = generator(merged_input)
        # Final cGAN generator model
        model = Model(inputs=[latent_input, label_input], outputs=generated_image)
        return model

    def train_step(self, data):
        if isinstance(data, tuple):
            real_images, real_labels = data
        else:
            real_images = data
            real_labels = tf.random.uniform([tf.shape(real_images)[0]], minval=0, maxval=self.num_classes, dtype=tf.int32)
        # Read the batch size from the batch axis. The original `real_images[0].shape[0]`
        # returned the image height (32), which only matched a batch size of 32 by coincidence.
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        real_images = tf.reshape(real_images, [batch_size, 32, 32, 3])
        fake_labels = tf.random.uniform([batch_size], minval=0, maxval=self.num_classes, dtype=tf.int32)
        generated_images = self.generator([random_latent_vectors, fake_labels])
        combined_images = tf.concat([generated_images, tf.cast(real_images, tf.float32)], axis=0)
        # Flatten the real labels and cast them to int32 so both label tensors share one dtype
        real_labels = tf.reshape(tf.cast(real_labels, tf.int32), [-1])
        combined_labels = tf.concat([fake_labels, real_labels], axis=0)  # Concatenate labels as well
        # Discriminator targets: 0 for generated images, 1 for real images
        discriminator_labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator([combined_images, combined_labels])
            d_loss = self.loss_fn(discriminator_labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
        # Train the generator: label its fakes as real to mislead the discriminator
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels])
            predictions = self.discriminator([generated_images, fake_labels])
            g_loss = self.loss_fn(misleading_labels, predictions)
        kl_loss = self.kl_divergence(real_images, generated_images)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
        }

    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(10, 10, height_ratios=[1]*10, width_ratios=[1]*10, hspace=0.25, wspace=0.2)
        # Rescale from the tanh range [-1, 1] to [0, 1] for display
        examples = (examples + 1) / 2.0
        class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
        for i in range(10*5):  # 5 images per class, 10 classes
            class_index = i // 5  # Determine class based on order
            ax = fig.add_subplot(gs[i % 5, class_index])
            ax.axis('off')
            ax.imshow(examples[i])
            if i % 5 == 0:
                ax.set_title(class_names[class_index], fontsize=8)
        # Plot for discriminator losses
        ax_loss = fig.add_subplot(gs[5:8, 0:3])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        # Plot for generator losses
        ax_g_loss = fig.add_subplot(gs[5:8, 3:7])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")
        # Plot for KL divergence
        ax_kl_div = fig.add_subplot(gs[5:8, 7:10])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")
        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.95)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    # NOTE: declared as a static method yet takes `self`, so callers must pass the instance explicitly
    @staticmethod
    def generate_fake_samples(self, generator, n_samples=5, latent_dim=100):
        X, y = [], []
        for class_label in range(10):  # CIFAR-10 has 10 classes
            # Generate latent points
            x_input = np.random.randn(latent_dim * n_samples)
            x_input = x_input.reshape(n_samples, latent_dim)
            # Create class labels
            labels = np.full((n_samples, 1), class_label)
            # Generate images
            images = generator.predict([x_input, labels], verbose=0)
            X.extend(images)
            y.extend(labels)
        return np.asarray(X), np.asarray(y)
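cgan = cGAN(latent_dim=100)
cgan.compile(
    d_optimizer=Adam(learning_rate=0.0003),
    g_optimizer=Adam(learning_rate=0.0003),
    # NOTE: the discriminator ends in a sigmoid, so it outputs probabilities rather than logits;
    # from_logits=False would match that. The original setting is kept here unchanged.
    loss_fn=BinaryCrossentropy(from_logits=True),
)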
cgan_callback = CustomCallback(d_losses = cgan.d_loss_list, g_losses = cgan.g_loss_list, kl_div=cgan.kl_div_list, model = cgan, filepath = "output/models/cgan/")
cgan.fit(X_train, epochs = 200, callbacks = [cgan_callback])
Epoch 1/200 1562/1562 [==============================] - 58s 35ms/step - d_loss: 0.4840 - g_loss: 2.5247 - kl_divergence: 0.6751 Epoch 2/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.5458 - g_loss: 1.5353 - kl_divergence: 0.4259 Epoch 3/200 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.5698 - g_loss: 1.4624 - kl_divergence: 0.4204 Epoch 4/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5858 - g_loss: 1.4409 - kl_divergence: 0.3989 Epoch 5/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.6061 - g_loss: 1.2307 - kl_divergence: 0.3679 Epoch 6/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.6118 - g_loss: 1.5189 - kl_divergence: 0.4050 Epoch 7/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5794 - g_loss: 1.3202 - kl_divergence: 0.3703 Epoch 8/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5902 - g_loss: 1.3091 - kl_divergence: 0.3445 Epoch 9/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.5745 - g_loss: 1.5144 - kl_divergence: 0.3493 Epoch 10/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5347 - g_loss: 1.4671 - kl_divergence: 0.3483 Epoch 11/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5201 - g_loss: 1.5850 - kl_divergence: 0.3519 Epoch 12/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.5057 - g_loss: 1.6061 - kl_divergence: 0.3539 Epoch 13/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4607 - g_loss: 1.8370 - kl_divergence: 0.3662 Epoch 14/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4397 - g_loss: 1.9564 - kl_divergence: 0.3696 Epoch 15/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4052 - g_loss: 2.0620 - kl_divergence: 0.3698 Epoch 16/200 1562/1562 [==============================] - 
51s 33ms/step - d_loss: 0.3994 - g_loss: 2.1406 - kl_divergence: 0.3779 Epoch 17/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.4285 - g_loss: 1.8424 - kl_divergence: 0.3678 Epoch 18/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4436 - g_loss: 1.7842 - kl_divergence: 0.3660 Epoch 19/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4729 - g_loss: 1.5721 - kl_divergence: 0.3576 Epoch 20/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4682 - g_loss: 1.5564 - kl_divergence: 0.3533 Epoch 21/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4808 - g_loss: 1.5065 - kl_divergence: 0.3470 Epoch 22/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4863 - g_loss: 1.4848 - kl_divergence: 0.3475 Epoch 23/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.5005 - g_loss: 1.4360 - kl_divergence: 0.3517 Epoch 24/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4891 - g_loss: 1.4734 - kl_divergence: 0.3485 Epoch 25/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4824 - g_loss: 1.5402 - kl_divergence: 0.3463 Epoch 26/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4841 - g_loss: 1.5289 - kl_divergence: 0.3442 Epoch 27/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4574 - g_loss: 1.5744 - kl_divergence: 0.3450 Epoch 28/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4942 - g_loss: 1.5094 - kl_divergence: 0.3430 Epoch 29/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4561 - g_loss: 1.5703 - kl_divergence: 0.3419 Epoch 30/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.4706 - g_loss: 1.5855 - kl_divergence: 0.3416 Epoch 31/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4739 - g_loss: 1.5134 - 
kl_divergence: 0.3412 Epoch 32/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4813 - g_loss: 1.5302 - kl_divergence: 0.3423 Epoch 33/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4710 - g_loss: 1.4704 - kl_divergence: 0.3450 Epoch 34/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4941 - g_loss: 1.4983 - kl_divergence: 0.3448 Epoch 35/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4807 - g_loss: 1.4557 - kl_divergence: 0.3456 Epoch 36/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4781 - g_loss: 1.4953 - kl_divergence: 0.3457 Epoch 37/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4887 - g_loss: 1.4842 - kl_divergence: 0.3469 Epoch 38/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4851 - g_loss: 1.5009 - kl_divergence: 0.3402 Epoch 39/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4746 - g_loss: 1.4589 - kl_divergence: 0.3485 Epoch 40/200 1562/1562 [==============================] - 53s 34ms/step - d_loss: 0.4696 - g_loss: 1.5598 - kl_divergence: 0.3434 Epoch 41/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4795 - g_loss: 1.4673 - kl_divergence: 0.3431 Epoch 42/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4643 - g_loss: 1.5405 - kl_divergence: 0.3409 Epoch 43/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4746 - g_loss: 1.5465 - kl_divergence: 0.3392 Epoch 44/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4611 - g_loss: 1.5420 - kl_divergence: 0.3384 Epoch 45/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4592 - g_loss: 1.5817 - kl_divergence: 0.3415 Epoch 46/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4678 - g_loss: 1.6379 - kl_divergence: 0.3394 Epoch 47/200 1562/1562 
[==============================] - 51s 33ms/step - d_loss: 0.4598 - g_loss: 1.5374 - kl_divergence: 0.3413 Epoch 48/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4592 - g_loss: 1.6224 - kl_divergence: 0.3402 Epoch 49/200 1562/1562 [==============================] - 54s 34ms/step - d_loss: 0.4636 - g_loss: 1.5936 - kl_divergence: 0.3361 Epoch 50/200 1562/1562 [==============================] - 53s 34ms/step - d_loss: 0.4593 - g_loss: 1.5885 - kl_divergence: 0.3355 Epoch 51/200 1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.4506 - g_loss: 1.6277 - kl_divergence: 0.3372 Epoch 52/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4474 - g_loss: 1.6525 - kl_divergence: 0.3379 Epoch 53/200 1562/1562 [==============================] - 54s 34ms/step - d_loss: 0.4449 - g_loss: 1.6906 - kl_divergence: 0.3371 Epoch 54/200 1562/1562 [==============================] - 54s 35ms/step - d_loss: 0.4386 - g_loss: 1.6731 - kl_divergence: 0.3436 Epoch 55/200 1562/1562 [==============================] - 54s 35ms/step - d_loss: 0.4413 - g_loss: 1.7110 - kl_divergence: 0.3354 Epoch 56/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4303 - g_loss: 1.7006 - kl_divergence: 0.3370 Epoch 57/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4274 - g_loss: 1.7386 - kl_divergence: 0.3360 Epoch 58/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4260 - g_loss: 1.7518 - kl_divergence: 0.3392 Epoch 59/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4263 - g_loss: 1.7579 - kl_divergence: 0.3404 Epoch 60/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4264 - g_loss: 1.7563 - kl_divergence: 0.3375 Epoch 61/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4216 - g_loss: 1.8059 - kl_divergence: 0.3370 Epoch 62/200 1562/1562 [==============================] - 55s 36ms/step 
- d_loss: 0.4181 - g_loss: 1.7909 - kl_divergence: 0.3396 Epoch 63/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4194 - g_loss: 1.8367 - kl_divergence: 0.3362 Epoch 64/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4147 - g_loss: 1.8420 - kl_divergence: 0.3386 Epoch 65/200 1562/1562 [==============================] - 58s 37ms/step - d_loss: 0.4142 - g_loss: 1.8307 - kl_divergence: 0.3376 Epoch 66/200 1562/1562 [==============================] - 56s 35ms/step - d_loss: 0.4173 - g_loss: 1.8338 - kl_divergence: 0.3382 Epoch 67/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4117 - g_loss: 1.8508 - kl_divergence: 0.3372 Epoch 68/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4075 - g_loss: 1.8894 - kl_divergence: 0.3358 Epoch 69/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4035 - g_loss: 1.8991 - kl_divergence: 0.3349 Epoch 70/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4053 - g_loss: 1.8909 - kl_divergence: 0.3370 Epoch 71/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3993 - g_loss: 1.9298 - kl_divergence: 0.3385 Epoch 72/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3977 - g_loss: 1.9482 - kl_divergence: 0.3382 Epoch 73/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3979 - g_loss: 1.9846 - kl_divergence: 0.3388 Epoch 74/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3983 - g_loss: 1.9537 - kl_divergence: 0.3379 Epoch 75/200 1562/1562 [==============================] - 54s 35ms/step - d_loss: 0.3905 - g_loss: 1.9815 - kl_divergence: 0.3371 Epoch 76/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3907 - g_loss: 2.0003 - kl_divergence: 0.3386 Epoch 77/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3883 - g_loss: 1.9946 - 
kl_divergence: 0.3390 Epoch 78/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3849 - g_loss: 2.0377 - kl_divergence: 0.3378 Epoch 79/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3852 - g_loss: 2.0344 - kl_divergence: 0.3414 Epoch 80/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3842 - g_loss: 2.0532 - kl_divergence: 0.3396 Epoch 81/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3803 - g_loss: 2.0818 - kl_divergence: 0.3402 Epoch 82/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3810 - g_loss: 2.0800 - kl_divergence: 0.3394 Epoch 83/200 1562/1562 [==============================] - 59s 38ms/step - d_loss: 0.3790 - g_loss: 2.0881 - kl_divergence: 0.3389 Epoch 84/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3766 - g_loss: 2.1098 - kl_divergence: 0.3380 Epoch 85/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3734 - g_loss: 2.1341 - kl_divergence: 0.3416 Epoch 86/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3699 - g_loss: 2.1488 - kl_divergence: 0.3401 Epoch 87/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3713 - g_loss: 2.1559 - kl_divergence: 0.3403 Epoch 88/200 1562/1562 [==============================] - 54s 34ms/step - d_loss: 0.3664 - g_loss: 2.1840 - kl_divergence: 0.3409 Epoch 89/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3697 - g_loss: 2.1760 - kl_divergence: 0.3404 Epoch 90/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3644 - g_loss: 2.1908 - kl_divergence: 0.3416 Epoch 91/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3671 - g_loss: 2.2437 - kl_divergence: 0.3387 Epoch 92/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3607 - g_loss: 2.2272 - kl_divergence: 0.3440 Epoch 93/200 1562/1562 
[==============================] - 55s 35ms/step - d_loss: 0.3568 - g_loss: 2.2641 - kl_divergence: 0.3428 Epoch 94/200 1562/1562 [==============================] - 53s 34ms/step - d_loss: 0.3582 - g_loss: 2.2703 - kl_divergence: 0.3453 Epoch 95/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3574 - g_loss: 2.2715 - kl_divergence: 0.3436 Epoch 96/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3563 - g_loss: 2.2839 - kl_divergence: 0.3463 Epoch 97/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3569 - g_loss: 2.3054 - kl_divergence: 0.3418 Epoch 98/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3517 - g_loss: 2.3075 - kl_divergence: 0.3440 Epoch 99/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3490 - g_loss: 2.3251 - kl_divergence: 0.3416 Epoch 100/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3461 - g_loss: 2.3574 - kl_divergence: 0.3486 Epoch 101/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3450 - g_loss: 2.3640 - kl_divergence: 0.3458 Epoch 102/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3435 - g_loss: 2.3998 - kl_divergence: 0.3447 Epoch 103/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3401 - g_loss: 2.4136 - kl_divergence: 0.3430 Epoch 104/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3394 - g_loss: 2.4441 - kl_divergence: 0.3462 Epoch 105/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3380 - g_loss: 2.4317 - kl_divergence: 0.3426 Epoch 106/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3369 - g_loss: 2.4472 - kl_divergence: 0.3467 Epoch 107/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3329 - g_loss: 2.4941 - kl_divergence: 0.3446 Epoch 108/200 1562/1562 [==============================] - 51s 
33ms/step - d_loss: 0.3307 - g_loss: 2.4686 - kl_divergence: 0.3431 Epoch 109/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3322 - g_loss: 2.5111 - kl_divergence: 0.3408 Epoch 110/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3535 - g_loss: 2.5937 - kl_divergence: 0.3419 Epoch 111/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3345 - g_loss: 2.4697 - kl_divergence: 0.3422 Epoch 112/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3301 - g_loss: 2.4993 - kl_divergence: 0.3429 Epoch 113/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3268 - g_loss: 2.5677 - kl_divergence: 0.3439 Epoch 114/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3242 - g_loss: 2.5815 - kl_divergence: 0.3412 Epoch 115/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3227 - g_loss: 2.5716 - kl_divergence: 0.3407 Epoch 116/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3230 - g_loss: 2.6080 - kl_divergence: 0.3451 Epoch 117/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3246 - g_loss: 2.6098 - kl_divergence: 0.3413 Epoch 118/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3198 - g_loss: 2.6258 - kl_divergence: 0.3451 Epoch 119/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3157 - g_loss: 2.6524 - kl_divergence: 0.3459 Epoch 120/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3141 - g_loss: 2.6667 - kl_divergence: 0.3419 Epoch 121/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3142 - g_loss: 2.6944 - kl_divergence: 0.3455 Epoch 122/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3134 - g_loss: 2.7040 - kl_divergence: 0.3468 Epoch 123/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3127 - 
g_loss: 2.7053 - kl_divergence: 0.3430 Epoch 124/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3095 - g_loss: 2.7472 - kl_divergence: 0.3472 Epoch 125/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3061 - g_loss: 2.7704 - kl_divergence: 0.3410 Epoch 126/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3128 - g_loss: 2.8328 - kl_divergence: 0.3419 Epoch 127/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3044 - g_loss: 2.7802 - kl_divergence: 0.3472 Epoch 128/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3051 - g_loss: 2.7825 - kl_divergence: 0.3439 Epoch 129/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3062 - g_loss: 2.7921 - kl_divergence: 0.3428 Epoch 130/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3003 - g_loss: 2.8314 - kl_divergence: 0.3449 Epoch 131/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3008 - g_loss: 2.8626 - kl_divergence: 0.3458 Epoch 132/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.3001 - g_loss: 2.8366 - kl_divergence: 0.3452 Epoch 133/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.2980 - g_loss: 2.8546 - kl_divergence: 0.3435 Epoch 134/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2993 - g_loss: 2.8737 - kl_divergence: 0.3431 Epoch 135/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2981 - g_loss: 2.8804 - kl_divergence: 0.3456 Epoch 136/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3016 - g_loss: 2.8909 - kl_divergence: 0.3435 Epoch 137/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2963 - g_loss: 2.8966 - kl_divergence: 0.3467 Epoch 138/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2950 - g_loss: 2.9014 - kl_divergence: 
0.3451 Epoch 139/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2954 - g_loss: 2.9124 - kl_divergence: 0.3446 Epoch 140/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2930 - g_loss: 2.9481 - kl_divergence: 0.3474 Epoch 141/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2938 - g_loss: 2.9328 - kl_divergence: 0.3456 Epoch 142/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2876 - g_loss: 2.9911 - kl_divergence: 0.3445 Epoch 143/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2889 - g_loss: 3.0045 - kl_divergence: 0.3457 Epoch 144/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2884 - g_loss: 2.9893 - kl_divergence: 0.3494 Epoch 145/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2863 - g_loss: 3.0292 - kl_divergence: 0.3454 Epoch 146/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2811 - g_loss: 3.0429 - kl_divergence: 0.3468 Epoch 147/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2841 - g_loss: 3.0720 - kl_divergence: 0.3456 Epoch 148/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2819 - g_loss: 3.0755 - kl_divergence: 0.3500 Epoch 149/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2839 - g_loss: 3.0660 - kl_divergence: 0.3517 Epoch 150/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2810 - g_loss: 3.0974 - kl_divergence: 0.3483 Epoch 151/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2808 - g_loss: 3.1015 - kl_divergence: 0.3474 Epoch 152/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2794 - g_loss: 3.1128 - kl_divergence: 0.3438 Epoch 153/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2734 - g_loss: 3.1639 - kl_divergence: 0.3466 Epoch 154/200 1562/1562 
[==============================] - 51s 33ms/step - d_loss: 0.2796 - g_loss: 3.1662 - kl_divergence: 0.3450 Epoch 155/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2796 - g_loss: 3.1588 - kl_divergence: 0.3458 Epoch 156/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2716 - g_loss: 3.1673 - kl_divergence: 0.3470 Epoch 157/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2726 - g_loss: 3.2010 - kl_divergence: 0.3474 Epoch 158/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2777 - g_loss: 3.2242 - kl_divergence: 0.3471 Epoch 159/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2729 - g_loss: 3.2318 - kl_divergence: 0.3468 Epoch 160/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2733 - g_loss: 3.2283 - kl_divergence: 0.3461 Epoch 161/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2734 - g_loss: 3.2335 - kl_divergence: 0.3458 Epoch 162/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2706 - g_loss: 3.2452 - kl_divergence: 0.3446 Epoch 163/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2691 - g_loss: 3.2610 - kl_divergence: 0.3454 Epoch 164/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2688 - g_loss: 3.2782 - kl_divergence: 0.3470 Epoch 165/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2725 - g_loss: 3.2630 - kl_divergence: 0.3488 Epoch 166/200 1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.2720 - g_loss: 3.2585 - kl_divergence: 0.3461 Epoch 167/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2697 - g_loss: 3.3011 - kl_divergence: 0.3495 Epoch 168/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2679 - g_loss: 3.2953 - kl_divergence: 0.3490 Epoch 169/200 1562/1562 [==============================] - 
52s 33ms/step - d_loss: 0.2714 - g_loss: 3.2939 - kl_divergence: 0.3484 Epoch 170/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2682 - g_loss: 3.3109 - kl_divergence: 0.3518 Epoch 171/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2687 - g_loss: 3.3092 - kl_divergence: 0.3532 Epoch 172/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2688 - g_loss: 3.3288 - kl_divergence: 0.3484 Epoch 173/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2662 - g_loss: 3.3398 - kl_divergence: 0.3483 Epoch 174/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2673 - g_loss: 3.3607 - kl_divergence: 0.3484 Epoch 175/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2629 - g_loss: 3.3719 - kl_divergence: 0.3501 Epoch 176/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2671 - g_loss: 3.3627 - kl_divergence: 0.3453 Epoch 177/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2626 - g_loss: 3.4028 - kl_divergence: 0.3446 Epoch 178/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2621 - g_loss: 3.4167 - kl_divergence: 0.3461 Epoch 179/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2632 - g_loss: 3.3936 - kl_divergence: 0.3508 Epoch 180/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2650 - g_loss: 3.3966 - kl_divergence: 0.3462 Epoch 181/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2631 - g_loss: 3.4136 - kl_divergence: 0.3490 Epoch 182/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2567 - g_loss: 3.4415 - kl_divergence: 0.3490 Epoch 183/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2609 - g_loss: 3.4225 - kl_divergence: 0.3505 Epoch 184/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2638 - 
g_loss: 3.4268 - kl_divergence: 0.3471 Epoch 185/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2574 - g_loss: 3.4730 - kl_divergence: 0.3484 Epoch 186/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2633 - g_loss: 3.4415 - kl_divergence: 0.3492 Epoch 187/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2632 - g_loss: 3.4339 - kl_divergence: 0.3469 Epoch 188/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2581 - g_loss: 3.4664 - kl_divergence: 0.3471 Epoch 189/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2596 - g_loss: 3.4776 - kl_divergence: 0.3480 Epoch 190/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2572 - g_loss: 3.4819 - kl_divergence: 0.3475 Epoch 191/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2612 - g_loss: 3.4703 - kl_divergence: 0.3525 Epoch 192/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2571 - g_loss: 3.4966 - kl_divergence: 0.3492 Epoch 193/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2538 - g_loss: 3.5485 - kl_divergence: 0.3495 Epoch 194/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2518 - g_loss: 3.5523 - kl_divergence: 0.3489 Epoch 195/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2567 - g_loss: 3.5505 - kl_divergence: 0.3480 Epoch 196/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2543 - g_loss: 3.5229 - kl_divergence: 0.3476 Epoch 197/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2543 - g_loss: 3.5500 - kl_divergence: 0.3482 Epoch 198/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2533 - g_loss: 3.5558 - kl_divergence: 0.3512 Epoch 199/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2536 - g_loss: 3.5719 - kl_divergence: 
0.3525 Epoch 200/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2524 - g_loss: 3.5922 - kl_divergence: 0.3531
<keras.callbacks.History at 0x22a08b19790>
input_folder = './output/models/cgan/generated/' # Replace with your frames directory
output_file = 'cgan_output_video.mp4' # Replace with your desired output file path
gnnf.create_video_from_frames(input_folder, output_file)
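`gnnf.create_video_from_frames` comes from a helper module that is not shown in this section, so its internals are unknown here. One detail any such helper has to handle is frame order: under plain string sorting, `generated_plot_e10.png` comes before `generated_plot_e2.png`. The sketch below assumes the frames follow the `generated_plot_e<N>.png` naming used by `save_plot` in this notebook, and shows one way to order them numerically. The function `sort_frames_by_epoch` is hypothetical, not part of `gnnf`.

```python
import re

def sort_frames_by_epoch(filenames):
    """Order frame filenames by the epoch number embedded in their name.

    Hypothetical helper: assumes names end in `_e<number>.png`.
    Files that do not match the pattern are dropped.
    """
    pattern = re.compile(r"_e(\d+)\.png$")
    matched = []
    for name in filenames:
        m = pattern.search(name)
        if m:
            matched.append((int(m.group(1)), name))
    # Sort on the integer epoch, not on the string
    return [name for _, name in sorted(matched)]

frames = [
    "generated_plot_e10.png",
    "generated_plot_e2.png",
    "generated_plot_e1.png",
    "notes.txt",
]
ordered = sort_frames_by_epoch(frames)
print(ordered)
```

If the video helper already sorts frames numerically, this step is unnecessary.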
A Wasserstein GAN (WGAN) is a GAN variant that addresses some of the training instability of standard GANs. It does so by replacing the loss function and training procedure of the original GAN formulation with ones based on the Wasserstein distance.
WGANs offer several advantages, including:
Like all GANs, WGANs have two main components: a generator, which creates images, and a critic, which takes the place of the discriminator.
Unlike standard GANs, the critic in WGANs doesn't directly classify samples as real or fake. Instead, it estimates the "distance" between the real and generated data distributions using the Wasserstein distance.
The generator and critic are trained in an adversarial way:
The critic is trained to maximize the gap between its scores on real data and its scores on generated data, which tightens its estimate of the Wasserstein distance. The generator is trained to minimize that estimated distance by raising the critic's scores on generated data. Over training, the generator gets better at creating realistic data, while the critic gets better at telling real and generated data apart. This competition is what drives improvements in the quality and diversity of the generated data.
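To make the idea of a distance between distributions concrete, here is a minimal sketch in plain Python for the one-dimensional case. There, the Wasserstein-1 distance between two equal-sized samples reduces to the mean absolute difference of their sorted values. The function name `wasserstein_1d` and the toy samples are illustrative only. This shortcut does not carry over to images, which is why a WGAN trains a critic to estimate the distance instead.

```python
def wasserstein_1d(sample_a, sample_b):
    """Wasserstein-1 distance between two equal-sized 1-D samples."""
    if len(sample_a) != len(sample_b):
        raise ValueError("this sketch only handles equal-sized samples")
    # In 1-D the optimal transport plan pairs the sorted values
    a, b = sorted(sample_a), sorted(sample_b)
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

real = [0.0, 1.0, 2.0, 3.0]
shifted = [x + 0.5 for x in real]   # slightly different distribution
far = [x + 5.0 for x in real]       # very different distribution

d_same = wasserstein_1d(real, real)
d_near = wasserstein_1d(real, shifted)
d_far = wasserstein_1d(real, far)
print(d_same, d_near, d_far)
```

The distance grows smoothly as the second sample moves away from the first. That smoothness is the property that gives the generator a usable training signal even when the two distributions barely overlap.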
The Wasserstein losses minimized by the critic and the generator are as follows: \begin{align*} &Critic \, Loss = \frac{1}{N} \sum_{i=1}^{N} \left[ f(g(z_i)) - f(x_i) \right] \\ &Generator \, Loss = -\frac{1}{N} \sum_{i=1}^{N} f(g(z_i)) \\ &\text{where}\\ &N \text{ is the number of observations in the batch,} \\ &f \text{ is the critic network's output (an unbounded score, not a probability),} \\ &x_i \text{ is the } i^{\text{th}} \text{ real data instance,} \\ &g(z_i) \text{ is the generated data instance from noise vector } z_i. \end{align*}
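def critic_loss(real_output, fake_output):
    # The critic minimizes mean(fake scores) - mean(real scores)
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

def generator_loss(fake_output):
    # The generator minimizes the negated mean critic score on generated images
    return -tf.reduce_mean(fake_output)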
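The two loss functions above can be checked by hand on a few toy scores. The sketch below repeats the same arithmetic in plain Python so that it runs without TensorFlow. The names `critic_loss_toy` and `generator_loss_toy` and the score values are made up for illustration.

```python
def mean(values):
    return sum(values) / len(values)

def critic_loss_toy(real_scores, fake_scores):
    # Same arithmetic as critic_loss: mean(fake) - mean(real)
    return mean(fake_scores) - mean(real_scores)

def generator_loss_toy(fake_scores):
    # Same arithmetic as generator_loss: -mean(fake)
    return -mean(fake_scores)

real_scores = [2.0, 3.0, 1.0]     # critic scores on real images, mean 2.0
fake_scores = [-1.0, 0.0, -2.0]   # critic scores on generated images, mean -1.0

c_loss = critic_loss_toy(real_scores, fake_scores)
g_loss = generator_loss_toy(fake_scores)
print(c_loss, g_loss)
```

Because the scores are not probabilities, neither loss is bounded. A more negative critic loss only means the critic's scores for the two batches are further apart.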
from tensorflow.keras.constraints import Constraint

class ClipConstraint(Constraint):
    """Clip layer weights to the range [-clip_value, clip_value]."""
    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        return tf.clip_by_value(weights, -self.clip_value, self.clip_value)
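As a quick illustration of what this constraint does to a weight tensor, the sketch below applies the same clamp to a plain Python list. The function `clip_weights` and the sample values are hypothetical. In Keras, a constraint is applied after each gradient update, and only to the layers that are given it through an argument such as `kernel_constraint`.

```python
def clip_weights(weights, clip_value):
    """Clamp every weight into [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, w)) for w in weights]

weights = [-0.5, -0.01, 0.0, 0.004, 0.3]
clipped = clip_weights(weights, 0.01)
print(clipped)
```

Values already inside the range pass through unchanged, while larger values are pinned to the boundary.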
num_classes = 10

class wGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)
        self.num_classes = num_classes
        self.CRITIC_UPDATES = 5  # critic updates per generator update

    def define_discriminator(self, in_shape=(32,32,3)):
        # Weight clipping, applied to the convolutional layers below
        constraint = ClipConstraint(0.01)
        # Image input
        image_input = Input(shape=in_shape)
        # Label input and embedding (these layers have no weight constraint)
        label_input = Input(shape=(1,))
        label_embedding = Embedding(num_classes, np.prod(in_shape))(label_input)
        label_embedding = Dense(np.prod(in_shape))(label_embedding)
        label_embedding = Reshape(in_shape)(label_embedding)
        # Concatenate image and label
        concatenated = Concatenate()([image_input, label_embedding])
        # Critic (discriminator) model
        x = Conv2D(64, (3,3), padding='same', kernel_constraint=constraint)(concatenated)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_constraint=constraint)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_constraint=constraint)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_constraint=constraint)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Flatten()(x)
        x = Dropout(0.4)(x)
        # Linear output: an unbounded score, not a probability (no weight constraint here)
        output = Dense(1)(x)
        # Define and compile model; train_step below uses critic_loss rather than this compiled loss
        model = Model(inputs=[image_input, label_input], outputs=output)
        model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model
    def define_generator(self, latent_dim):
        # Model for processing the labels
        label_input = tf.keras.Input(shape=(1,), dtype='int32')
        label_embedding = Embedding(num_classes, latent_dim)(label_input)
        label_embedding = Flatten()(label_embedding)
        # Model for processing the latent vector
        latent_input = tf.keras.Input(shape=(latent_dim,))
        # Combine label and latent inputs (resulting width: latent_dim * 2)
        merged_input = Concatenate()([latent_input, label_embedding])
        # Sequential model for the generator
        generator = Sequential([
            Dense(8 * 8 * 256, input_shape=(latent_dim * 2,)),
            LeakyReLU(alpha=0.2),
            Reshape((8, 8, 256)),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(3, (3, 3), activation='tanh', padding='same')
        ])
        # Pass the merged input through the generator model
        generated_image = generator(merged_input)
        # Final conditional generator model
        model = Model(inputs=[latent_input, label_input], outputs=generated_image)
        return model
    def train_step(self, data):
        # Unpack the data
        if isinstance(data, tuple):
            real_images, real_labels = data
        else:
            real_images = data
            # No labels supplied: fall back to random class labels
            real_labels = tf.random.uniform([tf.shape(real_images)[0]], minval=0, maxval=self.num_classes, dtype=tf.int32)
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        fake_labels = tf.random.uniform([batch_size], minval=0, maxval=self.num_classes, dtype=tf.int32)
        # Critic updates (the same latent vectors and labels are reused in every iteration)
        for _ in range(self.CRITIC_UPDATES):
            with tf.GradientTape() as tape:
                generated_images = self.generator([random_latent_vectors, fake_labels], training=True)
                real_output = self.discriminator([real_images, real_labels], training=True)
                fake_output = self.discriminator([generated_images, fake_labels], training=True)
                c_loss = critic_loss(real_output, fake_output)
            c_grads = tape.gradient(c_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(c_grads, self.discriminator.trainable_weights))
        # Generator update
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels], training=True)
            fake_output = self.discriminator([generated_images, fake_labels], training=True)
            g_loss = generator_loss(fake_output)
        # KL divergence is tracked as a metric only; it does not enter the generator gradients
        kl_loss = self.kl_divergence(real_images, generated_images)
        g_grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_weights))
        # Update metrics
        self.d_loss_tracker.update_state(c_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
        }
    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(10, 10, height_ratios=[1]*10, width_ratios=[1]*10, hspace=0.25, wspace=0.2)
        # Rescale images from [-1, 1] to [0, 1] for display
        examples = (examples + 1) / 2.0
        class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
        for i in range(10*5):  # 5 images per class, 10 classes
            class_index = i // 5  # Determine class based on order
            ax = fig.add_subplot(gs[i % 5, class_index])
            ax.axis('off')
            ax.imshow(examples[i])
            # Add class label text for the first image of each class
            if i % 5 == 0:
                ax.set_title(class_names[class_index], fontsize=8)
        # Plot for discriminator losses
        ax_loss = fig.add_subplot(gs[5:8, 0:3])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        # Plot for generator losses
        ax_g_loss = fig.add_subplot(gs[5:8, 3:7])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")
        # Plot for KL divergence
        ax_kl_div = fig.add_subplot(gs[5:8, 7:10])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")
        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.95)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()
    # Declared static but keeps a `self` parameter, so callers must pass the instance explicitly
    @staticmethod
    def generate_fake_samples(self, generator, n_samples=5, latent_dim=100):
        X, y = [], []
        for class_label in range(10):  # CIFAR-10 has 10 classes
            # Generate latent points
            x_input = np.random.randn(n_samples, latent_dim)
            # Create class labels
            labels = np.full((n_samples, 1), class_label)
            # Generate images
            images = generator.predict([x_input, labels], verbose=0)
            X.extend(images)
            y.extend(labels)
        return np.asarray(X), np.asarray(y)
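The layer stacks in `define_discriminator` and `define_generator` can be sanity-checked with a little shape arithmetic. The sketch below is plain Python, not Keras. It assumes the standard behaviour of `padding='same'`: a strided `Conv2D` outputs `ceil(input / stride)` pixels per side, and a `Conv2DTranspose` outputs `input * stride`. The helper names are illustrative.

```python
import math

def conv_same_out(size, stride):
    # Conv2D with padding='same': output side = ceil(input side / stride)
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    # Conv2DTranspose with padding='same': output side = input side * stride
    return size * stride

# Critic: four Conv2D layers with strides 1, 2, 2, 2 on a 32x32 input
critic_sizes = [32]
for stride in (1, 2, 2, 2):
    critic_sizes.append(conv_same_out(critic_sizes[-1], stride))
# Image (3 channels) concatenated with the reshaped label embedding (3 channels)
critic_input_channels = 3 + 3
# Flatten() after the last 64-filter convolution
flattened_features = critic_sizes[-1] ** 2 * 64

# Generator: Reshape to 8x8, then Conv2DTranspose with strides 2, 2, 1
generator_sizes = [8]
for stride in (2, 2, 1):
    generator_sizes.append(conv_transpose_same_out(generator_sizes[-1], stride))

print(critic_sizes, critic_input_channels, flattened_features, generator_sizes)
```

The generator therefore ends at 32x32, matching the CIFAR10 image size that the critic expects as input.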
wgan = wGAN(latent_dim=100)
wgan.compile(
    d_optimizer=RMSprop(learning_rate=0.0003),
    g_optimizer=RMSprop(learning_rate=0.0003),
    loss_fn=BinaryCrossentropy(from_logits=True),
)
wgan_callback = CustomCallback(d_losses=wgan.d_loss_list, g_losses=wgan.g_loss_list, kl_div=wgan.kl_div_list, model=wgan, filepath="output/models/wgan/")
wgan.fit(X_train, epochs=50, callbacks=[wgan_callback])
Epoch 1/50 1562/1562 [==============================] - 151s 92ms/step - d_loss: -2511.8191 - g_loss: 11509.7070 - kl_divergence: 0.5411 Epoch 2/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -38174.3750 - g_loss: 75383.7422 - kl_divergence: 0.5661 Epoch 3/50 1562/1562 [==============================] - 143s 91ms/step - d_loss: -53997.6914 - g_loss: 155898.1562 - kl_divergence: 0.5868 Epoch 4/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -174930.6094 - g_loss: -50041.0430 - kl_divergence: 0.6316 Epoch 5/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -541265.0625 - g_loss: -2412482.0000 - kl_divergence: 0.5908 Epoch 6/50 1562/1562 [==============================] - 143s 92ms/step - d_loss: -730157.8125 - g_loss: 77811.3906 - kl_divergence: 0.5558 Epoch 7/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -1167884.8750 - g_loss: -5160780.0000 - kl_divergence: 0.5474 Epoch 8/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -1581006.0000 - g_loss: 1626707.7500 - kl_divergence: 0.5673 Epoch 9/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -2397018.2500 - g_loss: -2872515.0000 - kl_divergence: 0.5713 Epoch 10/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -2372221.0000 - g_loss: -9411951.0000 - kl_divergence: 0.5423 Epoch 11/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -3415741.5000 - g_loss: -6726378.0000 - kl_divergence: 0.5977 Epoch 12/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -4656259.0000 - g_loss: -18112762.0000 - kl_divergence: 0.5754 Epoch 13/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -7106436.0000 - g_loss: 32966476.0000 - kl_divergence: 0.5420 Epoch 14/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -11302557.0000 - g_loss: 61351872.0000 - kl_divergence: 0.5448 Epoch 15/50 
1562/1562 [==============================] - 142s 91ms/step - d_loss: -17009860.0000 - g_loss: 67873368.0000 - kl_divergence: 0.4710 Epoch 16/50 1562/1562 [==============================] - 143s 91ms/step - d_loss: -11430128.0000 - g_loss: 26108232.0000 - kl_divergence: 0.4775 Epoch 17/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -20488420.0000 - g_loss: -12855508.0000 - kl_divergence: 0.5432 Epoch 18/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -127779392.0000 - g_loss: -469266592.0000 - kl_divergence: 0.5487 Epoch 19/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -120902504.0000 - g_loss: -562537280.0000 - kl_divergence: 0.5074 Epoch 20/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -54040672.0000 - g_loss: -150013616.0000 - kl_divergence: 0.4903 Epoch 21/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -31215888.0000 - g_loss: -30920370.0000 - kl_divergence: 0.4663 Epoch 22/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -133609240.0000 - g_loss: -782933440.0000 - kl_divergence: 0.4160 Epoch 23/50 1562/1562 [==============================] - 143s 91ms/step - d_loss: -188014512.0000 - g_loss: -852295680.0000 - kl_divergence: 0.4265 Epoch 24/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -139411168.0000 - g_loss: -618659520.0000 - kl_divergence: 0.4141 Epoch 25/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -48559948.0000 - g_loss: -223021840.0000 - kl_divergence: 0.4052 Epoch 26/50 1562/1562 [==============================] - 142s 91ms/step - d_loss: -68058672.0000 - g_loss: 60360144.0000 - kl_divergence: 0.4080 Epoch 27/50 1562/1562 [==============================] - 140s 89ms/step - d_loss: -302390752.0000 - g_loss: -1766296064.0000 - kl_divergence: 0.4210 Epoch 28/50 1562/1562 [==============================] - 140s 90ms/step - d_loss: -216782304.0000 - 
g_loss: 1197224320.0000 - kl_divergence: 0.4133 Epoch 29/50 1562/1562 [==============================] - 140s 89ms/step - d_loss: -298702592.0000 - g_loss: 1927712000.0000 - kl_divergence: 0.4263 Epoch 30/50 1562/1562 [==============================] - 140s 90ms/step - d_loss: -372681088.0000 - g_loss: 2415065088.0000 - kl_divergence: 0.4424 Epoch 31/50 1562/1562 [==============================] - 141s 90ms/step - d_loss: -416613312.0000 - g_loss: 2832033280.0000 - kl_divergence: 0.4723 Epoch 32/50 1562/1562 [==============================] - 140s 89ms/step - d_loss: -461932640.0000 - g_loss: 159228352.0000 - kl_divergence: 0.4745 Epoch 33/50 1562/1562 [==============================] - 140s 90ms/step - d_loss: -571692096.0000 - g_loss: -3261810176.0000 - kl_divergence: 0.4536 Epoch 34/50 1562/1562 [==============================] - 148s 94ms/step - d_loss: -488175552.0000 - g_loss: -2884788224.0000 - kl_divergence: 0.4420 Epoch 35/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -768050432.0000 - g_loss: -3712435456.0000 - kl_divergence: 0.4111 Epoch 36/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -858683200.0000 - g_loss: -4325497344.0000 - kl_divergence: 0.3874 Epoch 37/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -818794432.0000 - g_loss: -4564397056.0000 - kl_divergence: 0.3683 Epoch 38/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -169840064.0000 - g_loss: -301608896.0000 - kl_divergence: 0.3777 Epoch 39/50 1562/1562 [==============================] - 153s 98ms/step - d_loss: -201049216.0000 - g_loss: -12258592.0000 - kl_divergence: 0.3465 Epoch 40/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -197536288.0000 - g_loss: 1037425536.0000 - kl_divergence: 0.3395 Epoch 41/50 1562/1562 [==============================] - 154s 99ms/step - d_loss: -211085920.0000 - g_loss: 39040424.0000 - kl_divergence: 0.3374 Epoch 42/50 1562/1562 
[==============================] - 152s 97ms/step - d_loss: -204623872.0000 - g_loss: -304704864.0000 - kl_divergence: 0.3552 Epoch 43/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -295322240.0000 - g_loss: -1090108672.0000 - kl_divergence: 0.3444 Epoch 44/50 1562/1562 [==============================] - 153s 98ms/step - d_loss: -284562528.0000 - g_loss: -893282368.0000 - kl_divergence: 0.3511 Epoch 45/50 1562/1562 [==============================] - 150s 96ms/step - d_loss: -346499520.0000 - g_loss: -679465408.0000 - kl_divergence: 0.3725 Epoch 46/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -1048624512.0000 - g_loss: 4641689088.0000 - kl_divergence: 0.3748 Epoch 47/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -1603169536.0000 - g_loss: 8545432576.0000 - kl_divergence: 0.3751 Epoch 48/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -376864448.0000 - g_loss: 88227160.0000 - kl_divergence: 0.3829 Epoch 49/50 1562/1562 [==============================] - 152s 97ms/step - d_loss: -440872544.0000 - g_loss: -1773806464.0000 - kl_divergence: 0.3780 Epoch 50/50 1562/1562 [==============================] - 153s 98ms/step - d_loss: -1162849024.0000 - g_loss: 7700485632.0000 - kl_divergence: 0.3899
However, this implementation of wGAN appears to be ineffective: the logged d_loss and g_loss grow by several orders of magnitude over the 50 epochs, and training kept breaking down even after several attempts to tune it properly. Hence, we conclude that, for the given architecture, wGAN does not work.
Having attempted (and failed) to use wGAN for the task at hand, we now try an alternative model. For this, we shall use Hinge GAN.
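One rough way to quantify the instability is to compare the magnitude of the logged loss across epochs. The sketch below uses three d_loss values copied from the wGAN log above; the helper name `orders_of_magnitude` is purely illustrative and is not part of the notebook.

```python
import math

# d_loss values copied from the wGAN training log above (epochs 1, 10 and 50).
logged_d_loss = {1: -2511.8191, 10: -2372221.0, 50: -1162849024.0}


def orders_of_magnitude(first, last):
    """Return how many powers of ten |last| is larger than |first|."""
    return math.log10(abs(last) / abs(first))


growth = orders_of_magnitude(logged_d_loss[1], logged_d_loss[50])
print(f"|d_loss| grew by about {growth:.1f} orders of magnitude")
```

A Wasserstein critic whose outputs are effectively constrained would normally be expected to keep its loss bounded, so growth of this size is consistent with the conclusion that this configuration did not train properly.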
Hinge GANs are a type of Generative Adversarial Network (GAN) that utilize a "hinge loss" function to train their discriminator, the component responsible for judging real versus fake data. Unlike the standard binary cross-entropy loss in traditional GANs, hinge loss only penalizes the discriminator when it makes mistakes or fails to confidently distinguish between real and generated data. This approach offers several benefits:
Improved Stability: Hinge loss can mitigate the vanishing gradient problem that can plague traditional GANs during training. This tends to give smoother training and potentially faster convergence to stable models.
Better Focus on Margins: By pushing the scores of real and fake data apart, hinge loss encourages the discriminator to pay more attention to the quality of generated data rather than just classifying it correctly. This can lead to sharper and more realistic generations.
However, hinge loss also comes with drawbacks. Its emphasis on margins can sometimes lead to "mode collapse," where the generator gets stuck producing only a limited variety of outputs. Carefully choosing hyperparameters and architectures can help mitigate this issue.
Thus, we use Hinge GAN in the hope of stabilizing training and of generating higher-quality, more realistic, and more detailed images.
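The stability point above can be illustrated with a small numerical sketch. It compares the gradient of the original saturating minimax generator loss, log(1 - sigmoid(s)), with the gradient of the generator loss -s that is commonly paired with hinge loss, where s is the raw discriminator score on a generated image. This is an assumption-laden illustration: the non-saturating BCE variant behaves differently, and the function names are made up for this example.

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def saturating_grad(s):
    """d/ds of log(1 - sigmoid(s)), the saturating minimax generator loss."""
    return -sigmoid(s)


def hinge_generator_grad(s):
    """d/ds of -s, the generator loss commonly paired with hinge loss."""
    return -1.0


# A very negative score means the discriminator confidently rejects the fake.
for s in (-8.0, -2.0, 0.0, 2.0):
    print(f"score={s:+.1f}  saturating={saturating_grad(s):+.5f}  "
          f"hinge={hinge_generator_grad(s):+.1f}")
```

When the discriminator confidently rejects a sample (a very negative score), the saturating gradient shrinks towards zero, while the gradient of -s keeps a constant magnitude.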
The formula for hinge loss is as follows:
\begin{align*}
&D(x) = \max(0, 1 - t \cdot f(x)) \\
&\text{where} \\
&t = \begin{cases} 1 & \text{for real data (positive class)} \\ -1 & \text{for generated data (negative class)} \end{cases} \\
&f(x) \text{ is the discriminator's output for an input } x.
\end{align*}
While one might use hinge loss for both the generator and the discriminator, doing so may not yield the best results for the network. Hinge loss, as its name suggests, focuses on pushing real and fake data apart in the scoring space. This helps the discriminator learn to tell them apart, but it is not directly relevant to the generator's goal of creating realistic data. Penalizing the generator based on the discriminator's confidence, even for good outputs, can hinder its learning and potentially lead to suboptimal results. In the implementation below, the generator therefore minimizes the negated mean discriminator score on generated images instead.
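As a worked example of the formula, the sketch below evaluates the hinge loss on a handful of made-up discriminator scores and combines them the same way as the train_step further down (mean hinge on real scores plus mean hinge on fake scores for the discriminator, negated mean fake score for the generator). The scores and helper names are hypothetical and chosen only for illustration.

```python
def hinge(t, score):
    """Per-sample hinge loss max(0, 1 - t * f(x))."""
    return max(0.0, 1.0 - t * score)


def mean(values):
    return sum(values) / len(values)


# Hypothetical discriminator scores f(x), not taken from any training run.
real_scores = [2.0, 0.5, -1.0]   # t = +1 for real data
fake_scores = [-3.0, -0.5, 1.5]  # t = -1 for generated data

d_loss_real = mean([hinge(+1, s) for s in real_scores])
d_loss_fake = mean([hinge(-1, s) for s in fake_scores])
d_loss = d_loss_real + d_loss_fake

# Generator loss as used in train_step: the negated mean score on fakes.
g_loss = -mean(fake_scores)

print(d_loss_real, d_loss_fake, d_loss, g_loss)
```

Only scores inside the margin or on the wrong side of it contribute: the real score 2.0 and the fake score -3.0 both add zero to the discriminator loss.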
num_classes = 10


class hingeGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)
        self.num_classes = num_classes

    def define_discriminator(self, in_shape=(32,32,3)):
        # Image input
        image_input = Input(shape=in_shape)
        # Label input and embedding, reshaped into an image-sized label map
        label_input = Input(shape=(1,))
        label_embedding = Embedding(num_classes, np.prod(in_shape))(label_input)
        label_embedding = Dense(np.prod(in_shape))(label_embedding)
        label_embedding = Reshape(in_shape)(label_embedding)
        # Concatenate image and label along the channel axis
        concatenated = Concatenate()([image_input, label_embedding])
        # Discriminator model
        x = Conv2D(64, (3,3), padding='same')(concatenated)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.001, l2=0.001))(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Flatten()(x)
        x = Dropout(0.4)(x)
        # Linear output: hinge loss operates on unbounded scores
        output = Dense(1, activation='linear')(x)
        # Define the model
        model = Model(inputs=[image_input, label_input], outputs=output)
        return model

    def define_generator(self, latent_dim):
        # Model for processing the labels
        label_input = tf.keras.Input(shape=(1,), dtype='int32')
        label_embedding = Embedding(num_classes, latent_dim)(label_input)
        label_embedding = Flatten()(label_embedding)
        # Model for processing the latent vector
        latent_input = tf.keras.Input(shape=(latent_dim,))
        # Combine label and latent inputs
        merged_input = Concatenate()([latent_input, label_embedding])
        # Sequential model for the generator
        generator = Sequential([
            Dense(8 * 8 * 256, input_shape=(latent_dim * 2,)),
            LeakyReLU(alpha=0.2),
            Reshape((8, 8, 256)),
            Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(3, (3, 3), activation='tanh', padding='same')
        ])
        # Pass the merged input through the generator model
        generated_image = generator(merged_input)
        # Final cGAN generator model
        model = Model(inputs=[latent_input, label_input], outputs=generated_image)
        return model

    def train_step(self, data):
        # Unpack the data. Its structure depends on your dataset and
        # whether it includes labels
        if isinstance(data, tuple):
            real_images, real_labels = data
        else:
            real_images = data
            # No labels supplied: fall back to random class labels
            real_labels = tf.random.uniform([tf.shape(real_images)[0]], minval=0, maxval=self.num_classes, dtype=tf.int32)
        # Batch size is the leading dimension of the image batch.
        # (real_images[0].shape[0] would read the image height, 32, which only
        # matches the batch size when batches hold exactly 32 images.)
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # real_images = (real_images - 127.5) / 127.5 # Normalize to [-1, 1] if your real_images are in [0, 255]
        real_images = tf.reshape(real_images, [batch_size, 32, 32, 3])
        # Sample random class labels for the fake images
        fake_labels = tf.random.uniform([batch_size], minval=0, maxval=self.num_classes, dtype=tf.int32)
        # Generate fake images
        generated_images = self.generator([random_latent_vectors, fake_labels])
        real_labels = tf.squeeze(real_labels)
        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions_on_real = self.discriminator([real_images, real_labels])
            predictions_on_fake = self.discriminator([generated_images, fake_labels])
            # Hinge loss for the discriminator
            d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - predictions_on_real))
            d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + predictions_on_fake))
            d_loss = d_loss_real + d_loss_fake
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
        # Train the generator
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels])
            predictions = self.discriminator([generated_images, fake_labels])
            # Generator loss: negated mean discriminator score on fake images
            g_loss = -tf.reduce_mean(predictions)
            kl_loss = self.kl_divergence(real_images, generated_images)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        # Update metrics
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)
        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
        }

    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(10, 10, height_ratios=[1]*10, width_ratios=[1]*10, hspace=0.25, wspace=0.2)
        # Rescale generator output from [-1, 1] to [0, 1] for imshow
        examples = (examples + 1) / 2.0
        class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
        for i in range(10*5): # 5 images per class, 10 classes
            class_index = i // 5 # Determine class based on order
            ax = fig.add_subplot(gs[i % 5, class_index])
            # print(i % 5, class_index)
            ax.axis('off')
            ax.imshow(examples[i])
            # Add class label text for the first image of each class
            if i % 5 == 0:
                ax.set_title(class_names[class_index], fontsize=8)
        # Plot for discriminator losses
        ax_loss = fig.add_subplot(gs[5:8, 0:3])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        # Plot for generator losses
        ax_g_loss = fig.add_subplot(gs[5:8, 3:7])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")
        # Plot for KL divergence
        ax_kl_div = fig.add_subplot(gs[5:8, 7:10])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")
        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.95)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    # NOTE: this method is declared static but keeps a leading `self`
    # parameter, so callers have to pass that argument explicitly.
    @staticmethod
    def generate_fake_samples(self, generator, n_samples=5, latent_dim=100):
        X, y = [], []
        for class_label in range(10): # CIFAR-10 has 10 classes
            # Generate latent points
            x_input = np.random.randn(latent_dim * n_samples)
            x_input = x_input.reshape(n_samples, latent_dim)
            # Create class labels
            labels = np.full((n_samples, 1), class_label)
            # Generate images
            images = generator.predict([x_input, labels], verbose=0)
            X.extend(images)
            y.extend(labels)
        return np.asarray(X), np.asarray(y)
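As a sanity check on the architecture above, the tensor shapes can be traced by hand. The sketch below is plain arithmetic under the usual rules for 'same' padding (a strided convolution gives ceil(size / stride), a strided transposed convolution gives size * stride); the helper names are hypothetical and no Keras model is built.

```python
import math


def same_padding_out(size, stride):
    """Spatial size after a 'same'-padded convolution."""
    return math.ceil(size / stride)


def same_padding_transpose_out(size, stride):
    """Spatial size after a 'same'-padded transposed convolution."""
    return size * stride


# Discriminator: image (32, 32, 3) concatenated with the label map (32, 32, 3).
disc_channels_in = 3 + 3
size = 32
for stride in (1, 2, 2, 2):  # strides of the four Conv2D layers
    size = same_padding_out(size, stride)
flatten_units = size * size * 128  # the last Conv2D has 128 filters

# Generator: latent vector concatenated with a label embedding of equal width.
latent_dim = 100
dense_in = latent_dim * 2
g_size = 8  # spatial size after Reshape((8, 8, 256))
for stride in (2, 2, 1):  # strides of the three Conv2DTranspose layers
    g_size = same_padding_transpose_out(g_size, stride)

print(disc_channels_in, size, flatten_units, dense_in, g_size)
```

The generator therefore ends at 32 x 32 pixels, matching the CIFAR10 image size that the discriminator expects, and the discriminator's dense layer sees 4 * 4 * 128 = 2048 features.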
hinge_gan = hingeGAN(latent_dim=100)
hinge_gan.compile(
    d_optimizer=Adam(learning_rate=0.0003),
    g_optimizer=Adam(learning_rate=0.0003),
    # Not referenced by hingeGAN.train_step, which computes the hinge losses directly
    loss_fn=BinaryCrossentropy(),
)
hinge_gan_callback = CustomCallback(d_losses=hinge_gan.d_loss_list, g_losses=hinge_gan.g_loss_list, kl_div=hinge_gan.kl_div_list, model=hinge_gan, filepath="output/models/hinge_gan/")
hinge_gan.fit(X_train, epochs=200, callbacks=[hinge_gan_callback])
Epoch 1/200 1562/1562 [==============================] - 52s 32ms/step - d_loss: 1.4068 - g_loss: 1.1703 - kl_divergence: 0.7009 Epoch 2/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.2997 - g_loss: 1.1075 - kl_divergence: 0.5582 Epoch 3/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.5804 - g_loss: 0.7518 - kl_divergence: 0.5271 Epoch 4/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.4858 - g_loss: 0.8374 - kl_divergence: 0.4826 Epoch 5/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.3063 - g_loss: 1.1355 - kl_divergence: 0.5613 Epoch 6/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.5198 - g_loss: 0.8947 - kl_divergence: 0.4339 Epoch 7/200 1562/1562 [==============================] - 54s 34ms/step - d_loss: 1.4658 - g_loss: 0.8365 - kl_divergence: 0.4356 Epoch 8/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6505 - g_loss: 0.5848 - kl_divergence: 0.3753 Epoch 9/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6594 - g_loss: 0.6821 - kl_divergence: 0.3933 Epoch 10/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6946 - g_loss: 0.5013 - kl_divergence: 0.3628 Epoch 11/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.7023 - g_loss: 0.5359 - kl_divergence: 0.3619 Epoch 12/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6076 - g_loss: 0.7303 - kl_divergence: 0.3474 Epoch 13/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.5535 - g_loss: 0.7393 - kl_divergence: 0.3444 Epoch 14/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.4983 - g_loss: 0.7002 - kl_divergence: 0.3393 Epoch 15/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.5308 - g_loss: 0.6908 - kl_divergence: 0.3411 Epoch 16/200 1562/1562 [==============================] - 
51s 33ms/step - d_loss: 1.5017 - g_loss: 0.6909 - kl_divergence: 0.3418 Epoch 17/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3154 - g_loss: 0.8463 - kl_divergence: 0.3460 Epoch 18/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1563 - g_loss: 1.0569 - kl_divergence: 0.3702 Epoch 19/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1680 - g_loss: 1.0744 - kl_divergence: 0.3824 Epoch 20/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1684 - g_loss: 1.0503 - kl_divergence: 0.3869 Epoch 21/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.2071 - g_loss: 0.9494 - kl_divergence: 0.4044 Epoch 22/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 1.2979 - g_loss: 0.8252 - kl_divergence: 0.3918 Epoch 23/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3386 - g_loss: 0.7789 - kl_divergence: 0.3829 Epoch 24/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3664 - g_loss: 0.7522 - kl_divergence: 0.3735 Epoch 25/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3652 - g_loss: 0.7219 - kl_divergence: 0.3888 Epoch 26/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.3442 - g_loss: 0.7601 - kl_divergence: 0.3748 Epoch 27/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.3432 - g_loss: 0.7415 - kl_divergence: 0.3680 Epoch 28/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3300 - g_loss: 0.7771 - kl_divergence: 0.3579 Epoch 29/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.2607 - g_loss: 0.8332 - kl_divergence: 0.3576 Epoch 30/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 1.2657 - g_loss: 0.8211 - kl_divergence: 0.3631 Epoch 31/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.2001 - g_loss: 0.8541 - 
kl_divergence: 0.3567 Epoch 32/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1823 - g_loss: 0.8885 - kl_divergence: 0.3626 Epoch 33/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1570 - g_loss: 0.8930 - kl_divergence: 0.3535 Epoch 34/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1634 - g_loss: 0.8873 - kl_divergence: 0.3482 Epoch 35/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1483 - g_loss: 0.8892 - kl_divergence: 0.3487 Epoch 36/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1788 - g_loss: 0.8548 - kl_divergence: 0.3495 Epoch 37/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1767 - g_loss: 0.8683 - kl_divergence: 0.3467 Epoch 38/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2036 - g_loss: 0.8460 - kl_divergence: 0.3521 Epoch 39/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2442 - g_loss: 0.8155 - kl_divergence: 0.3522 Epoch 40/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.2047 - g_loss: 0.8271 - kl_divergence: 0.3573 Epoch 41/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2177 - g_loss: 0.8495 - kl_divergence: 0.3535 Epoch 42/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2080 - g_loss: 0.8276 - kl_divergence: 0.3514 Epoch 43/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2192 - g_loss: 0.8659 - kl_divergence: 0.3538 Epoch 44/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2127 - g_loss: 0.8250 - kl_divergence: 0.3563 Epoch 45/200 1562/1562 [==============================] - 49s 32ms/step - d_loss: 1.2102 - g_loss: 0.8466 - kl_divergence: 0.3498 Epoch 46/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2119 - g_loss: 0.8302 - kl_divergence: 0.3522 Epoch 47/200 1562/1562 
[==============================] - 49s 31ms/step - d_loss: 1.2072 - g_loss: 0.8404 - kl_divergence: 0.3503 Epoch 48/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2030 - g_loss: 0.8301 - kl_divergence: 0.3504 Epoch 49/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2062 - g_loss: 0.8528 - kl_divergence: 0.3473 Epoch 50/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1961 - g_loss: 0.8442 - kl_divergence: 0.3496 Epoch 51/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2025 - g_loss: 0.8651 - kl_divergence: 0.3466 Epoch 52/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1947 - g_loss: 0.8362 - kl_divergence: 0.3529 Epoch 53/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1871 - g_loss: 0.8579 - kl_divergence: 0.3443 Epoch 54/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1739 - g_loss: 0.8694 - kl_divergence: 0.3435 Epoch 55/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2025 - g_loss: 0.8642 - kl_divergence: 0.3440 Epoch 56/200 1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1924 - g_loss: 0.8862 - kl_divergence: 0.3461 Epoch 57/200 1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1943 - g_loss: 0.8427 - kl_divergence: 0.3479 Epoch 58/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1770 - g_loss: 0.8710 - kl_divergence: 0.3549 Epoch 59/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1782 - g_loss: 0.8660 - kl_divergence: 0.3445 Epoch 60/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1712 - g_loss: 0.8715 - kl_divergence: 0.3423 Epoch 61/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1678 - g_loss: 0.8867 - kl_divergence: 0.3435 Epoch 62/200 1562/1562 [==============================] - 50s 32ms/step 
- d_loss: 1.1596 - g_loss: 0.8751 - kl_divergence: 0.3466 Epoch 63/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1657 - g_loss: 0.8960 - kl_divergence: 0.3444 Epoch 64/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1642 - g_loss: 0.8928 - kl_divergence: 0.3463 Epoch 65/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1646 - g_loss: 0.8916 - kl_divergence: 0.3468 Epoch 66/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1600 - g_loss: 0.8949 - kl_divergence: 0.3432 Epoch 67/200 1562/1562 [==============================] - 52s 33ms/step - d_loss: 1.1828 - g_loss: 0.9018 - kl_divergence: 0.3432 Epoch 68/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1559 - g_loss: 0.8802 - kl_divergence: 0.3491 Epoch 69/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1576 - g_loss: 0.9157 - kl_divergence: 0.3448 Epoch 70/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1659 - g_loss: 0.8853 - kl_divergence: 0.3429 Epoch 71/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1595 - g_loss: 0.9050 - kl_divergence: 0.3417 Epoch 72/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1586 - g_loss: 0.8963 - kl_divergence: 0.3431 Epoch 73/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1713 - g_loss: 0.9004 - kl_divergence: 0.3425 Epoch 74/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1631 - g_loss: 0.8904 - kl_divergence: 0.3463 Epoch 75/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1568 - g_loss: 0.9570 - kl_divergence: 0.3459 Epoch 76/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1544 - g_loss: 0.8804 - kl_divergence: 0.3461 Epoch 77/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1442 - g_loss: 0.9052 - 
kl_divergence: 0.3444 Epoch 78/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1503 - g_loss: 0.9182 - kl_divergence: 0.3437 Epoch 79/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1501 - g_loss: 0.9152 - kl_divergence: 0.3443 Epoch 80/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1457 - g_loss: 0.9098 - kl_divergence: 0.3465 Epoch 81/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1458 - g_loss: 0.9083 - kl_divergence: 0.3472 Epoch 82/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1446 - g_loss: 0.9175 - kl_divergence: 0.3492 Epoch 83/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1347 - g_loss: 0.9641 - kl_divergence: 0.3436 Epoch 84/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1482 - g_loss: 0.9012 - kl_divergence: 0.3442 Epoch 85/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1334 - g_loss: 0.9249 - kl_divergence: 0.3455 Epoch 86/200 1562/1562 [==============================] - 53s 34ms/step - d_loss: 1.1392 - g_loss: 0.9208 - kl_divergence: 0.3465 Epoch 87/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1310 - g_loss: 0.9467 - kl_divergence: 0.3433 Epoch 88/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1334 - g_loss: 0.9430 - kl_divergence: 0.3436 Epoch 89/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1438 - g_loss: 0.9275 - kl_divergence: 0.3511 Epoch 90/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1332 - g_loss: 0.9193 - kl_divergence: 0.3465 Epoch 91/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1261 - g_loss: 0.9217 - kl_divergence: 0.3454 Epoch 92/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1287 - g_loss: 0.9298 - kl_divergence: 0.3465 Epoch 93/200 1562/1562 
[==============================] - 50s 32ms/step - d_loss: 1.1289 - g_loss: 0.9622 - kl_divergence: 0.3455 Epoch 94/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1365 - g_loss: 0.9233 - kl_divergence: 0.3472 Epoch 95/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1166 - g_loss: 0.9496 - kl_divergence: 0.3462 Epoch 96/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1111 - g_loss: 0.9699 - kl_divergence: 0.3430 Epoch 97/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1102 - g_loss: 0.9419 - kl_divergence: 0.3431 Epoch 98/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1112 - g_loss: 0.9633 - kl_divergence: 0.3443 Epoch 99/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1169 - g_loss: 0.9709 - kl_divergence: 0.3436 Epoch 100/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1140 - g_loss: 0.9564 - kl_divergence: 0.3420 Epoch 101/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1138 - g_loss: 0.9648 - kl_divergence: 0.3471 Epoch 102/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1131 - g_loss: 0.9483 - kl_divergence: 0.3412 Epoch 103/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1061 - g_loss: 0.9601 - kl_divergence: 0.3441 Epoch 104/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1016 - g_loss: 0.9799 - kl_divergence: 0.3454 Epoch 105/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1091 - g_loss: 0.9835 - kl_divergence: 0.3436 Epoch 106/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1042 - g_loss: 0.9706 - kl_divergence: 0.3435 Epoch 107/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1091 - g_loss: 0.9762 - kl_divergence: 0.3441 Epoch 108/200 1562/1562 [==============================] - 50s 
32ms/step - d_loss: 1.1133 - g_loss: 0.9590 - kl_divergence: 0.3473 Epoch 109/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1072 - g_loss: 0.9735 - kl_divergence: 0.3448 Epoch 110/200 1562/1562 [==============================] - 54s 34ms/step - d_loss: 1.1011 - g_loss: 0.9659 - kl_divergence: 0.3434 Epoch 111/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0891 - g_loss: 0.9771 - kl_divergence: 0.3445 Epoch 112/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0943 - g_loss: 0.9774 - kl_divergence: 0.3464 Epoch 113/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0956 - g_loss: 0.9813 - kl_divergence: 0.3424 Epoch 114/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0931 - g_loss: 0.9837 - kl_divergence: 0.3460 Epoch 115/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0891 - g_loss: 1.0101 - kl_divergence: 0.3486 Epoch 116/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0906 - g_loss: 0.9853 - kl_divergence: 0.3543 Epoch 117/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0859 - g_loss: 0.9850 - kl_divergence: 0.3429 Epoch 118/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0752 - g_loss: 1.0073 - kl_divergence: 0.3436 Epoch 119/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0840 - g_loss: 0.9904 - kl_divergence: 0.3455 Epoch 120/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0879 - g_loss: 0.9863 - kl_divergence: 0.3484 Epoch 121/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0602 - g_loss: 1.0297 - kl_divergence: 0.3452 Epoch 122/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0681 - g_loss: 0.9994 - kl_divergence: 0.3433 Epoch 123/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0790 - 
g_loss: 1.0140 - kl_divergence: 0.3453 Epoch 124/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0661 - g_loss: 1.0104 - kl_divergence: 0.3429 Epoch 125/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0746 - g_loss: 1.0140 - kl_divergence: 0.3458 Epoch 126/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0631 - g_loss: 1.0134 - kl_divergence: 0.3437 Epoch 127/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0578 - g_loss: 1.0399 - kl_divergence: 0.3463 Epoch 128/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0543 - g_loss: 1.0314 - kl_divergence: 0.3422 Epoch 129/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0683 - g_loss: 1.0400 - kl_divergence: 0.3437 Epoch 130/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0517 - g_loss: 1.0381 - kl_divergence: 0.3462 Epoch 131/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0576 - g_loss: 1.0286 - kl_divergence: 0.3426 Epoch 132/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0572 - g_loss: 1.0345 - kl_divergence: 0.3432 Epoch 133/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0579 - g_loss: 1.0400 - kl_divergence: 0.3436 Epoch 134/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0512 - g_loss: 1.0413 - kl_divergence: 0.3452 Epoch 135/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0485 - g_loss: 1.0525 - kl_divergence: 0.3435 Epoch 136/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0488 - g_loss: 1.0618 - kl_divergence: 0.3425 Epoch 137/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0471 - g_loss: 1.0503 - kl_divergence: 0.3521 Epoch 138/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0446 - g_loss: 1.0619 - kl_divergence: 
0.3450 Epoch 139/200 1562/1562 [==============================] - 55s 35ms/step - d_loss: 1.0402 - g_loss: 1.0503 - kl_divergence: 0.3445 Epoch 140/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0456 - g_loss: 1.0398 - kl_divergence: 0.3434 Epoch 141/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0426 - g_loss: 1.0680 - kl_divergence: 0.3460 Epoch 142/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0513 - g_loss: 1.0747 - kl_divergence: 0.3453 Epoch 143/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0383 - g_loss: 1.0625 - kl_divergence: 0.3447 Epoch 144/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0350 - g_loss: 1.0663 - kl_divergence: 0.3455 Epoch 145/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0323 - g_loss: 1.0645 - kl_divergence: 0.3459 Epoch 146/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0269 - g_loss: 1.0613 - kl_divergence: 0.3414 Epoch 147/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0286 - g_loss: 1.0912 - kl_divergence: 0.3431 Epoch 148/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0291 - g_loss: 1.0562 - kl_divergence: 0.3452 Epoch 149/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0133 - g_loss: 1.0851 - kl_divergence: 0.3460 Epoch 150/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0292 - g_loss: 1.0763 - kl_divergence: 0.3473 Epoch 151/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0199 - g_loss: 1.0885 - kl_divergence: 0.3437 Epoch 152/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0239 - g_loss: 1.0846 - kl_divergence: 0.3465 Epoch 153/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0202 - g_loss: 1.0812 - kl_divergence: 0.3435 Epoch 154/200 1562/1562 
[==============================] - 50s 32ms/step - d_loss: 1.0230 - g_loss: 1.0968 - kl_divergence: 0.3474 Epoch 155/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0128 - g_loss: 1.0996 - kl_divergence: 0.3483 Epoch 156/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0180 - g_loss: 1.0773 - kl_divergence: 0.3434 Epoch 157/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0084 - g_loss: 1.0954 - kl_divergence: 0.3450 Epoch 158/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0274 - g_loss: 1.0952 - kl_divergence: 0.3480 Epoch 159/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0117 - g_loss: 1.0851 - kl_divergence: 0.3566 Epoch 160/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0025 - g_loss: 1.1046 - kl_divergence: 0.3431 Epoch 161/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0130 - g_loss: 1.1085 - kl_divergence: 0.3425 Epoch 162/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0145 - g_loss: 1.1034 - kl_divergence: 0.3457 Epoch 163/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.0003 - g_loss: 1.1156 - kl_divergence: 0.3489 Epoch 164/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 0.9967 - g_loss: 1.1218 - kl_divergence: 0.3468 Epoch 165/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9979 - g_loss: 1.1281 - kl_divergence: 0.3450 Epoch 166/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.0051 - g_loss: 1.1207 - kl_divergence: 0.3417 Epoch 167/200 1562/1562 [==============================] - 51s 32ms/step - d_loss: 0.9906 - g_loss: 1.1343 - kl_divergence: 0.3428 Epoch 168/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9867 - g_loss: 1.1278 - kl_divergence: 0.3435 Epoch 169/200 1562/1562 [==============================] - 
50s 32ms/step - d_loss: 0.9873 - g_loss: 1.1408 - kl_divergence: 0.3460 Epoch 170/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9899 - g_loss: 1.1222 - kl_divergence: 0.3444 Epoch 171/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9945 - g_loss: 1.1174 - kl_divergence: 0.3499 Epoch 172/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0026 - g_loss: 1.1190 - kl_divergence: 0.3439 Epoch 173/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9868 - g_loss: 1.1336 - kl_divergence: 0.3441 Epoch 174/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9915 - g_loss: 1.1241 - kl_divergence: 0.3441 Epoch 175/200 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.9822 - g_loss: 1.1513 - kl_divergence: 0.3437 Epoch 176/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9804 - g_loss: 1.1462 - kl_divergence: 0.3556 Epoch 177/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9786 - g_loss: 1.1407 - kl_divergence: 0.3465 Epoch 178/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9720 - g_loss: 1.1447 - kl_divergence: 0.3450 Epoch 179/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9881 - g_loss: 1.1381 - kl_divergence: 0.3457 Epoch 180/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9783 - g_loss: 1.1500 - kl_divergence: 0.3439 Epoch 181/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9749 - g_loss: 1.1415 - kl_divergence: 0.3549 Epoch 182/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9770 - g_loss: 1.1448 - kl_divergence: 0.3448 Epoch 183/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9796 - g_loss: 1.1363 - kl_divergence: 0.3487 Epoch 184/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9818 - 
g_loss: 1.1650 - kl_divergence: 0.3474 Epoch 185/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9675 - g_loss: 1.1617 - kl_divergence: 0.3460 Epoch 186/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9752 - g_loss: 1.1621 - kl_divergence: 0.3438 Epoch 187/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9677 - g_loss: 1.1749 - kl_divergence: 0.3468 Epoch 188/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9676 - g_loss: 1.1823 - kl_divergence: 0.3456 Epoch 189/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9656 - g_loss: 1.1680 - kl_divergence: 0.3436 Epoch 190/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9589 - g_loss: 1.1822 - kl_divergence: 0.3466 Epoch 191/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9564 - g_loss: 1.1801 - kl_divergence: 0.3467 Epoch 192/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9578 - g_loss: 1.1828 - kl_divergence: 0.3419 Epoch 193/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9534 - g_loss: 1.1911 - kl_divergence: 0.3455 Epoch 194/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9585 - g_loss: 1.1841 - kl_divergence: 0.3462 Epoch 195/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9560 - g_loss: 1.1895 - kl_divergence: 0.3441 Epoch 196/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9481 - g_loss: 1.1864 - kl_divergence: 0.3454 Epoch 197/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9547 - g_loss: 1.1838 - kl_divergence: 0.3428 Epoch 198/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9475 - g_loss: 1.1966 - kl_divergence: 0.3469 Epoch 199/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9551 - g_loss: 1.2710 - kl_divergence: 
0.3457 Epoch 200/200 1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9684 - g_loss: 1.1674 - kl_divergence: 0.3464
<keras.callbacks.History at 0x1f892a588e0>
From the generated images, we can see that Hinge GAN actually performs quite well, with a lot of images containing some sort of identifiable object, while also showing detail that the other models are unable to replicate.
Now, we shall make use of FID, as well as visual inspection to help us decide the best model. Since wGAN collapsed, it shall not be used in this comparison.
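FID compares the mean and covariance of feature vectors extracted from real and generated images: d² = ||μ_r − μ_f||² + Tr(Σ_r + Σ_f − 2(Σ_r Σ_f)^{1/2}). The sketch below is not the notebook's calcFID, which is defined earlier. It is a minimal NumPy illustration of the distance itself, assuming the features (for example InceptionV3 activations) have already been extracted. The name frechet_distance and the random test features are hypothetical, and the trace of the matrix square root is computed from eigenvalues rather than with scipy.linalg.sqrtm.

```python
import numpy as np

def frechet_distance(feats_real, feats_fake):
    """Frechet distance between two sets of feature vectors (rows = samples)."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_f = np.cov(feats_fake, rowvar=False)
    # Tr(sqrt(cov_r @ cov_f)) is the sum of the square roots of the
    # eigenvalues of the product, which are real and non-negative in theory.
    eigvals = np.linalg.eigvals(cov_r @ cov_f)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r) + np.trace(cov_f) - 2.0 * tr_sqrt)

# Hypothetical features: 500 samples with 4 dimensions each
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 4))
same = frechet_distance(real, real)           # identical sets: close to 0
shifted = frechet_distance(real, real + 3.0)  # mean shift only: close to 4 * 3**2
```

A lower value means the two feature distributions are closer, which is why the model with the smaller FID is treated as the better one in the comparison that follows.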
dcgan_images, _ = dcgan.generate_fake_samples(dcgan, dcgan.generator, n_samples = 1000, latent_dim = 100)
dcgan_fid = calcFID(dcgan_images, num_images = 1000)
print(f"FID for DCGAN is {dcgan_fid}")
32/32 [==============================] - 22s 657ms/step 32/32 [==============================] - 21s 662ms/step FID for DCGAN is 89.84017345710066
cgan_images, _ = cgan.generate_fake_samples(cgan, cgan.generator, n_samples=100, latent_dim = 100)
cgan_fid = calcFID(cgan_images, num_images = 1000)
print(f"FID for cGAN is {cgan_fid}")
32/32 [==============================] - 26s 800ms/step 32/32 [==============================] - 25s 798ms/step FID for cGAN is 67.36210799376948
hinge_gan_images, _ = hinge_gan.generate_fake_samples(hinge_gan, hinge_gan.generator, n_samples = 100, latent_dim = 100)
hinge_gan_fid = calcFID(hinge_gan_images, num_images = 1000)
print(f"FID for Hinge GAN is {hinge_gan_fid}")
32/32 [==============================] - 5s 88ms/step 32/32 [==============================] - 3s 85ms/step FID for Hinge GAN is 70.16999625299437
From the quantitative analysis, cGAN performs best (FID 67.36, against 70.17 for Hinge GAN and 89.84 for DCGAN). However, on visual inspection, the images from Hinge GAN appear better and more detailed, while its FID is only slightly worse. Hence, we shall use Hinge GAN as our final model and improve it from here.
def batch_images(images, batch_size):
    """Split the images into batches."""
    for i in range(0, len(images), batch_size):
        yield images[i:i + batch_size]

def display_images_in_grid(images, grid_size, title = None):
    """Display images in a grid."""
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15))
    axs = axs.flatten()
    for img, ax in zip(images, axs):
        ax.imshow(img)
        ax.axis('off')
    if title is not None:
        plt.suptitle(title, y = 0.92)
    plt.show()

# Usage Example
all_images = hinge_gan.generate_fake_samples(hinge_gan, hinge_gan.generator, n_samples=10, latent_dim=100)[0]
all_images = (all_images + 1) / 2.0 # Scale images to [0, 1]
batches = list(batch_images(all_images, 100))
for batch in batches:
    display_images_in_grid(batch, 10)
Above are some images generated by Hinge GAN (100 per class). We can see that the images are quite realistic and well defined. However, some images are blobby or poorly defined. We shall try to optimize the model further to reduce this.
Now that we have completed the larger steps in model improvement, we take a smaller, final step to bring the model closer to its peak performance by tuning its hyperparameters. In our case, we shall only tune the optimizer, as the loss function is already well suited to our use case. We shall compare the Adam, SGD, and RMSProp optimizers.
Stochastic Gradient Descent (SGD) is one of the most basic and widely used optimization algorithms in machine learning and deep learning. It's a variant of gradient descent where, instead of performing computations on the whole dataset (which can be computationally intensive for large datasets), SGD updates the model's weights using only a single or a few samples at a time. This makes the algorithm much faster and more suitable for large datasets. The formula for SGD is as follows.
$$ w_{t+1} = w_t - \eta \cdot \nabla L(w_t, x_i, y_i) $$
where $w_{t+1}$ is the updated weight vector at time $t+1$, $w_t$ is the weight vector at time $t$, $\eta$ is the learning rate, and $\nabla L(w_t, x_i, y_i)$ is the gradient of the loss function $L$ with respect to the weights $w$, evaluated at a randomly chosen data point $(x_i, y_i)$.
RMSProp, short for Root Mean Square Propagation, is an adaptive learning rate method proposed by Geoffrey Hinton. It addresses some of the limitations of SGD, especially in the context of minimizing functions in very high-dimensional spaces. RMSProp adjusts the learning rate for each weight based on the recent magnitudes of the gradients for that weight. This means that the learning rate is reduced for weights that consistently receive high gradients, which helps in faster convergence especially in situations involving oscillations. The formula for RMSProp is as follows.
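One common way of writing the RMSProp update, using the same $w_t$ and $\eta$ as the SGD formula, is sketched here. The symbols $g_t$, $v_t$, $\rho$ and $\epsilon$ are introduced for this sketch and do not appear elsewhere in this notebook.
$$ \begin{aligned} g_t &= \nabla L(w_t, x_i, y_i) \\ v_t &= \rho \, v_{t-1} + (1 - \rho) \, g_t^2 \\ w_{t+1} &= w_t - \frac{\eta}{\sqrt{v_t} + \epsilon} \, g_t \end{aligned} $$
Here $v_t$ is a running average of the squared gradients for each weight, $\rho$ is its decay rate (often around 0.9), and $\epsilon$ is a small constant that avoids division by zero. All operations are applied element-wise, so each weight effectively gets its own step size.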
Although Adam and RMSProp both adapt their effective learning rates during training, the initial learning rate still strongly affects how the model performs. Hence, we shall try three learning rates for each optimizer and compare the results.
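To illustrate why the starting learning rate still matters for an adaptive optimizer, here is a minimal NumPy sketch of a plain SGD step and an RMSProp step applied to the toy loss L(w) = 0.5 * w². The function names, the toy loss, and the values of rho and eps are assumptions chosen for illustration. This is not the Keras implementation.

```python
import numpy as np

def sgd_step(w, grad, lr):
    """Plain SGD: move against the gradient, scaled by the learning rate."""
    return w - lr * grad

def rmsprop_step(w, grad, v, lr, rho=0.9, eps=1e-7):
    """RMSProp: scale the step by a running average of squared gradients."""
    v = rho * v + (1.0 - rho) * grad ** 2
    w = w - lr * grad / (np.sqrt(v) + eps)
    return w, v

def minimise(step_name, lr, n_steps=100):
    """Minimise L(w) = 0.5 * w**2 (gradient = w), starting from w = 5."""
    w, v = np.array([5.0]), np.zeros(1)
    for _ in range(n_steps):
        grad = w
        if step_name == "sgd":
            w = sgd_step(w, grad, lr)
        else:
            w, v = rmsprop_step(w, grad, v, lr)
    return float(abs(w[0]))

# Distance from the optimum (w = 0) after 100 steps for each setting
results = {(name, lr): minimise(name, lr)
           for name in ("sgd", "rmsprop")
           for lr in (0.0001, 0.01)}
```

In this toy setting, both optimizers end up closer to the optimum with the larger learning rate, which suggests the initial value is worth tuning even when the optimizer adapts its step sizes.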
from tensorflow.keras.optimizers import SGD
from IPython.display import clear_output
optimizer_list = [
    # Order and learning rates match the entries of name_list
    Adam(learning_rate=0.0001),
    Adam(learning_rate=0.0002),
    Adam(learning_rate=0.0005),
    SGD(learning_rate=0.005),
    SGD(learning_rate=0.05),
    SGD(learning_rate=0.01),
    RMSprop(learning_rate=0.0005),
    RMSprop(learning_rate=0.0002),
    RMSprop(learning_rate=0.0001),
]
name_list = [
'Adam_LR_0.0001',
'Adam_LR_0.0002',
'Adam_LR_0.0005',
'SGD_LR_0.005',
'SGD_LR_0.05',
'SGD_LR_0.01',
'RMSprop_LR_0.0005',
'RMSprop_LR_0.0002',
'RMSprop_LR_0.0001'
]
tuned_Adam_LR_0_0001 = None
tuned_Adam_LR_0_0002 = None
tuned_Adam_LR_0_0005 = None
tuned_SGD_LR_0_005 = None
tuned_SGD_LR_0_05 = None
tuned_SGD_LR_0_01 = None
tuned_RMSprop_LR_0_0005 = None
tuned_RMSprop_LR_0_0002 = None
tuned_RMSprop_LR_0_0001 = None
model_list = [tuned_Adam_LR_0_0001, tuned_Adam_LR_0_0002, tuned_Adam_LR_0_0005, tuned_SGD_LR_0_005, tuned_SGD_LR_0_05, tuned_SGD_LR_0_01, tuned_RMSprop_LR_0_0005, tuned_RMSprop_LR_0_0002, tuned_RMSprop_LR_0_0001]
def model_hypertuner(optimizer, name, model, train_length = 30):
    """Continue training the Hinge GAN with the given optimizer; return its FID and generated images."""
    clear_output(wait=True)
    print(f"Now attempting to tune {name}")
    # Start every run from the same pre-trained Hinge GAN weights
    model = hingeGAN(latent_dim=100)
    model.build(())
    model.load_weights('output/models/hinge_gan/weights/weights_199.h5')
    model.compile(
        d_optimizer=optimizer,
        g_optimizer=optimizer,
        loss_fn=Hinge(),  # keep the hinge loss; only the optimizer is being tuned
    )
    model_callback = CustomCallback(d_losses = model.d_loss_list, g_losses = model.g_loss_list, kl_div=model.kl_div_list, model = model, filepath = f"output/models/hypertune/{name}/")
    model.fit(X_train, epochs = train_length, callbacks = [model_callback])
    gc.collect()
    model_images, _ = model.generate_fake_samples(model, model.generator, n_samples = 100, latent_dim = 100)
    model_fid = calcFID(model_images, num_images = 1000)
    return model_fid, model_images
output_fid_list = []
output_images_list = []
for optimizer, name, model in zip(optimizer_list, name_list, model_list):
    # model_hypertuner returns (FID score, generated images)
    fid, images = model_hypertuner(optimizer, name, model, train_length=30)
    output_fid_list.append(fid)
    output_images_list.append(images)
Now attempting to tune RMSprop_LR_0.0001 Epoch 1/30 1562/1562 [==============================] - 58s 36ms/step - d_loss: 0.8175 - g_loss: 1.3469 - kl_divergence: 0.3477 Epoch 2/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7706 - g_loss: 1.3849 - kl_divergence: 0.3518 Epoch 3/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7500 - g_loss: 1.4072 - kl_divergence: 0.3519 Epoch 4/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7411 - g_loss: 1.4317 - kl_divergence: 0.3532 Epoch 5/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7346 - g_loss: 1.4437 - kl_divergence: 0.3529 Epoch 6/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7279 - g_loss: 1.4502 - kl_divergence: 0.3515 Epoch 7/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7164 - g_loss: 1.4729 - kl_divergence: 0.3557 Epoch 8/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7203 - g_loss: 1.4696 - kl_divergence: 0.3653 Epoch 9/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7101 - g_loss: 1.4918 - kl_divergence: 0.3579 Epoch 10/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7082 - g_loss: 1.5021 - kl_divergence: 0.3561 Epoch 11/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7004 - g_loss: 1.5168 - kl_divergence: 0.3538 Epoch 12/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7009 - g_loss: 1.5403 - kl_divergence: 0.3560 Epoch 13/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6940 - g_loss: 1.5389 - kl_divergence: 0.3687 Epoch 14/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6973 - g_loss: 1.5483 - kl_divergence: 0.3571 Epoch 15/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6922 - g_loss: 1.5561 - kl_divergence: 0.3571 Epoch 16/30 1562/1562 
[==============================] - 57s 37ms/step - d_loss: 0.6926 - g_loss: 1.5567 - kl_divergence: 0.3570 Epoch 17/30 1562/1562 [==============================] - 62s 39ms/step - d_loss: 0.6930 - g_loss: 1.5601 - kl_divergence: 0.3521 Epoch 18/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6888 - g_loss: 1.5766 - kl_divergence: 0.3517 Epoch 19/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6896 - g_loss: 1.5757 - kl_divergence: 0.3523 Epoch 20/30 1562/1562 [==============================] - 60s 38ms/step - d_loss: 0.6815 - g_loss: 1.5937 - kl_divergence: 0.3575 Epoch 21/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6826 - g_loss: 1.6008 - kl_divergence: 0.3604 Epoch 22/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6823 - g_loss: 1.6140 - kl_divergence: 0.3562 Epoch 23/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6810 - g_loss: 1.6086 - kl_divergence: 0.3591 Epoch 24/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6675 - g_loss: 1.6305 - kl_divergence: 0.3548 Epoch 25/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6785 - g_loss: 1.6191 - kl_divergence: 0.3641 Epoch 26/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6818 - g_loss: 1.6346 - kl_divergence: 0.3557 Epoch 27/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6763 - g_loss: 1.6464 - kl_divergence: 0.3532 Epoch 28/30 1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6748 - g_loss: 1.6534 - kl_divergence: 0.3547 Epoch 29/30 1562/1562 [==============================] - 59s 38ms/step - d_loss: 0.6718 - g_loss: 1.6539 - kl_divergence: 0.3544 Epoch 30/30 1562/1562 [==============================] - 60s 38ms/step - d_loss: 0.6735 - g_loss: 1.6584 - kl_divergence: 0.3579 32/32 [==============================] - 4s 88ms/step 32/32 
[==============================] - 3s 87ms/step
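The tuning loop above pairs optimizer_list and name_list purely by position, so a label can silently drift away from the setting it describes. One possible safeguard, sketched here with hypothetical helper names (run_configs, run_name, build_runs), is to describe each run once and derive the label from the settings themselves. The learning rates are those listed in name_list.

```python
# Hypothetical helper: each run is described once, as
# (optimizer class name, learning rate).
run_configs = [
    ("Adam", 0.0001), ("Adam", 0.0002), ("Adam", 0.0005),
    ("SGD", 0.005), ("SGD", 0.05), ("SGD", 0.01),
    ("RMSprop", 0.0005), ("RMSprop", 0.0002), ("RMSprop", 0.0001),
]

def run_name(opt_name, lr):
    """Build a label such as 'Adam_LR_0.0001' from the settings themselves."""
    return f"{opt_name}_LR_{lr:g}"

def build_runs(configs):
    """Return (label, optimizer class name, keyword arguments) for each run."""
    return [(run_name(opt, lr), opt, {"learning_rate": lr}) for opt, lr in configs]

runs = build_runs(run_configs)
# With TensorFlow imported as tf, an optimizer could then be created per run,
# for example: getattr(tf.keras.optimizers, opt_name)(**kwargs)
```

Because the label is computed from the same tuple that configures the optimizer, the two cannot disagree.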
import pandas as pd
# Assumption: copy_arr holds the FID scores from the tuning loop, in name_list order
copy_arr = list(output_fid_list)
data = {'Name': name_list, 'Value': copy_arr}
df = pd.DataFrame(data)
plt.figure(figsize=(10,6)) # Optional: Set figure size
chart = sns.barplot(x='Value', y='Name', data=df)
chart.set_xlabel('FID Score', fontdict={'size': 12})
plt.tight_layout()
# Annotate each bar with its FID score
for index, value in enumerate(copy_arr):
    plt.text(value, index, round(value,3), color='black', va='center')
plt.title('FID Score based on optimizer') # Optional: Set title
plt.show()
From the above plot, we can see that there are 6 optimizer configurations which produce better results. We shall generate 100 images with each optimizer, and visually inspect the images to see which produces the best results.
target_index = [1,2,3,6,7,8]
for index in target_index:
    print(f"Displaying images for {name_list[index]}")
    all_images = output_images_list[index]
    all_images = (all_images + 1) / 2.0 # Scale images to [0, 1]
    np.random.shuffle(all_images)
    batches = list(batch_images(all_images[:100], 100))
    for batch in batches:
        display_images_in_grid(batch, 10, title = f"Generated Images for {name_list[index]}\nFID = {copy_arr[index]:.3f}")
Displaying images for Adam_LR_0.0002
Displaying images for Adam_LR_0.0005
Displaying images for SGD_LR_0.005
Displaying images for RMSprop_LR_0.0005
Displaying images for RMSprop_LR_0.0002
Displaying images for RMSprop_LR_0.0001
From the above, we can see that all the images look similar. However, the images generated with the RMSprop optimizer at a learning rate of 0.0001 seem the most promising: they have the most detail of all, although they are also more misshapen. This suggests that with further tuning, the model may be able to learn the missing details and generate better images.
Data Augmentation
Now, we shall attempt to further improve the model by performing data augmentation before feeding the images to the model. Data augmentation expands our training data by applying random transformations, such as flipping, cropping and rotating, to existing samples. This exposes the model to more aspects of the data, with the aim of generating more realistic images.
Sources:
https://www.tensorflow.org/tutorials/images/data_augmentation
https://keras.io/api/layers/preprocessing_layers/image_augmentation/
We are going to create a function that performs data augmentation using random flips, random 90-degree rotations and random crops (the tf.image counterparts of RandomFlip, RandomRotation and RandomCrop).
def augment_dataset(dataset):
    """Apply random flips, 90-degree rotations and crops to every image in the dataset."""
    def augment(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        # Rotate by a random multiple of 90 degrees
        image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32))
        # Crop a random 28x28 patch, then resize it back to 32x32
        image = tf.image.random_crop(image, size=[image.shape[0] - 4, image.shape[1] - 4, image.shape[2]])
        image = tf.image.resize(image, [32, 32])
        return image, label
    return dataset.map(augment)

X_train_augmented = augment_dataset(X_train.unbatch())
Having applied the data augmentation to X_train, we now plot 100 of the augmented images.
plt.figure(figsize = (30,30))
for i, (image, _) in enumerate(X_train_augmented.take(100)):
    image = (image + 1) / 2.0
    plt.subplot(10, 10, i + 1)
    plt.imshow(image.numpy())
    plt.axis('off')
plt.show()
We are now going to combine the original training dataset, X_train, with the augmented training dataset, X_train_augmented.
After doing so, we will have double the number of training images, at 99968 (32 images are lost because the incomplete final batch is dropped).
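The image counts can be checked with a short calculation. This sketch assumes the 50,000-image CIFAR10 training set and the batch size of 32 used in this notebook. The helper name images_after_batching is hypothetical.

```python
def images_after_batching(n_images, batch_size):
    """Number of images kept when the incomplete final batch is dropped."""
    return (n_images // batch_size) * batch_size

original = images_after_batching(50_000, 32)  # images left in the batched dataset
combined = 2 * original                       # original plus augmented copy
lost = 2 * 50_000 - combined                  # images dropped across both copies
steps_per_epoch = combined // 32              # batches per epoch on the combined set
```

Here original is 49984, combined is 99968, lost is 32, and steps_per_epoch is 3124, which matches the dataset lengths printed by the notebook and the step count shown in the training log for the augmented run.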
X_train_augmented = X_train.unbatch().concatenate(X_train_augmented)
print(f"Length of initial dataset: {X_train.unbatch().cardinality().numpy()} images")
print(f"Length of augmented dataset: {X_train_augmented.cardinality().numpy()} images")
Length of initial dataset: 49984 images Length of augmented dataset: 99968 images
Finally, we are going to feed the augmented data to the model for final training. We shall train two models for 30 epochs each, one with augmented and one with unaugmented data, to see the effect of the augmented data on the model.
optimized_hinge_gan_augmented = hingeGAN(latent_dim=100)
optimized_hinge_gan_augmented.build(())
optimized_hinge_gan_augmented.load_weights('output/models/hypertune/RMSprop_LR_0.0001/weights//weights_29.h5')
optimized_hinge_gan_augmented.compile(
d_optimizer=RMSprop(learning_rate=0.0001),
g_optimizer=RMSprop(learning_rate=0.0001),
loss_fn=Hinge(),
)
hinge_gan_callback = CustomCallback(d_losses = optimized_hinge_gan_augmented.d_loss_list, g_losses = optimized_hinge_gan_augmented.g_loss_list, kl_div=optimized_hinge_gan_augmented.kl_div_list, model = optimized_hinge_gan_augmented, filepath = "output/models/optimized_hinge_gan_augmented/")
optimized_hinge_gan_augmented.fit(X_train_augmented.batch(32, drop_remainder=True), epochs = 30, callbacks = [hinge_gan_callback])
Epoch 1/30 3124/3124 [==============================] - 113s 35ms/step - d_loss: 0.9349 - g_loss: 1.4121 - kl_divergence: 0.2578 Epoch 2/30 3124/3124 [==============================] - 114s 37ms/step - d_loss: 1.0235 - g_loss: 1.3359 - kl_divergence: 0.2421 Epoch 3/30 3124/3124 [==============================] - 111s 36ms/step - d_loss: 1.0222 - g_loss: 1.3069 - kl_divergence: 0.2145 Epoch 4/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0289 - g_loss: 1.2999 - kl_divergence: 0.2144 Epoch 5/30 3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0299 - g_loss: 1.2838 - kl_divergence: 0.2193 Epoch 6/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0394 - g_loss: 1.2809 - kl_divergence: 0.2169 Epoch 7/30 3124/3124 [==============================] - 112s 36ms/step - d_loss: 1.0327 - g_loss: 1.2758 - kl_divergence: 0.2150 Epoch 8/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0340 - g_loss: 1.2634 - kl_divergence: 0.2107 Epoch 9/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0312 - g_loss: 1.2546 - kl_divergence: 0.2120 Epoch 10/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0337 - g_loss: 1.2696 - kl_divergence: 0.2089 Epoch 11/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0454 - g_loss: 1.2416 - kl_divergence: 0.2100 Epoch 12/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0356 - g_loss: 1.2528 - kl_divergence: 0.2134 Epoch 13/30 3124/3124 [==============================] - 112s 36ms/step - d_loss: 1.0337 - g_loss: 1.2572 - kl_divergence: 0.2063 Epoch 14/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0338 - g_loss: 1.2512 - kl_divergence: 0.2067 Epoch 15/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0362 - g_loss: 1.2508 - kl_divergence: 0.2093 Epoch 16/30 3124/3124 [==============================] - 
111s 35ms/step - d_loss: 1.0273 - g_loss: 1.2512 - kl_divergence: 0.2046 Epoch 17/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0267 - g_loss: 1.2469 - kl_divergence: 0.2124 Epoch 18/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0314 - g_loss: 1.2428 - kl_divergence: 0.2057 Epoch 19/30 3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0306 - g_loss: 1.2419 - kl_divergence: 0.2076 Epoch 20/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0357 - g_loss: 1.2344 - kl_divergence: 0.2067 Epoch 21/30 3124/3124 [==============================] - 112s 36ms/step - d_loss: 1.0303 - g_loss: 1.2390 - kl_divergence: 0.2010 Epoch 22/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0314 - g_loss: 1.2403 - kl_divergence: 0.2154 Epoch 23/30 3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0332 - g_loss: 1.2493 - kl_divergence: 0.2138 Epoch 24/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0334 - g_loss: 1.2367 - kl_divergence: 0.2097 Epoch 25/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0321 - g_loss: 1.2443 - kl_divergence: 0.2077 Epoch 26/30 3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0290 - g_loss: 1.2360 - kl_divergence: 0.2079 Epoch 27/30 3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0315 - g_loss: 1.2444 - kl_divergence: 0.2063 Epoch 28/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0400 - g_loss: 1.2315 - kl_divergence: 0.2025 Epoch 29/30 3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0379 - g_loss: 1.2260 - kl_divergence: 0.2179 Epoch 30/30 3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0300 - g_loss: 1.2352 - kl_divergence: 0.2022
model_images, _ = optimized_hinge_gan_augmented.generate_fake_samples(optimized_hinge_gan_augmented, optimized_hinge_gan_augmented.generator, n_samples = 100, latent_dim = 100)
model_fid = calcFID(model_images, num_images = 1000)
print(f"FID for this model is {model_fid}")
32/32 [==============================] - 4s 80ms/step 32/32 [==============================] - 3s 80ms/step FID for this model is 92.08255961757423
optimized_hinge_gan = hingeGAN(latent_dim=100)
optimized_hinge_gan.build(())
optimized_hinge_gan.load_weights('output/models/hypertune/RMSprop_LR_0.0001/weights//weights_29.h5')
optimized_hinge_gan.compile(
d_optimizer=RMSprop(learning_rate=0.0001),
g_optimizer=RMSprop(learning_rate=0.0001),
loss_fn=Hinge(),
)
hinge_gan_callback = CustomCallback(d_losses = optimized_hinge_gan.d_loss_list, g_losses = optimized_hinge_gan.g_loss_list, kl_div=optimized_hinge_gan.kl_div_list, model = optimized_hinge_gan, filepath = "output/models/optimized_hinge_gan/")
optimized_hinge_gan.fit(X_train, epochs = 30, callbacks = [hinge_gan_callback])
Epoch 1/30 1562/1562 [==============================] - 58s 36ms/step - d_loss: 0.6968 - g_loss: 1.6341 - kl_divergence: 0.3676 Epoch 2/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6986 - g_loss: 1.6358 - kl_divergence: 0.3562 Epoch 3/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6957 - g_loss: 1.6435 - kl_divergence: 0.3568 Epoch 4/30 1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.6983 - g_loss: 1.6359 - kl_divergence: 0.3534 Epoch 5/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6955 - g_loss: 1.6453 - kl_divergence: 0.3548 Epoch 6/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6957 - g_loss: 1.6457 - kl_divergence: 0.3520 Epoch 7/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6908 - g_loss: 1.6564 - kl_divergence: 0.3573 Epoch 8/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6943 - g_loss: 1.6470 - kl_divergence: 0.3551 Epoch 9/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6864 - g_loss: 1.6681 - kl_divergence: 0.3562 Epoch 10/30 1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.6934 - g_loss: 1.6683 - kl_divergence: 0.3561 Epoch 11/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6923 - g_loss: 1.6623 - kl_divergence: 0.3554 Epoch 12/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6931 - g_loss: 1.6750 - kl_divergence: 0.3542 Epoch 13/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6851 - g_loss: 1.6814 - kl_divergence: 0.3563 Epoch 14/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6882 - g_loss: 1.6800 - kl_divergence: 0.3591 Epoch 15/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6868 - g_loss: 1.6793 - kl_divergence: 0.3587 Epoch 16/30 1562/1562 [==============================] - 57s 36ms/step - 
d_loss: 0.6901 - g_loss: 1.6842 - kl_divergence: 0.3555 Epoch 17/30 1562/1562 [==============================] - 58s 37ms/step - d_loss: 0.6888 - g_loss: 1.6784 - kl_divergence: 0.3547 Epoch 18/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6900 - g_loss: 1.6878 - kl_divergence: 0.3528 Epoch 19/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6867 - g_loss: 1.6957 - kl_divergence: 0.3564 Epoch 20/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6809 - g_loss: 1.6891 - kl_divergence: 0.3566 Epoch 21/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6833 - g_loss: 1.7030 - kl_divergence: 0.3561 Epoch 22/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6760 - g_loss: 1.7161 - kl_divergence: 0.3544 Epoch 23/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6781 - g_loss: 1.7147 - kl_divergence: 0.3546 Epoch 24/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6727 - g_loss: 1.7177 - kl_divergence: 0.3550 Epoch 25/30 1562/1562 [==============================] - 58s 37ms/step - d_loss: 0.6823 - g_loss: 1.7063 - kl_divergence: 0.3533 Epoch 26/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6836 - g_loss: 1.7184 - kl_divergence: 0.3557 Epoch 27/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6811 - g_loss: 1.7192 - kl_divergence: 0.3527 Epoch 28/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6753 - g_loss: 1.7416 - kl_divergence: 0.3564 Epoch 29/30 1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6737 - g_loss: 1.7314 - kl_divergence: 0.3550 Epoch 30/30 1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6846 - g_loss: 1.7264 - kl_divergence: 0.3556
<keras.callbacks.History at 0x1a4367166a0>
model_images, _ = optimized_hinge_gan.generate_fake_samples(optimized_hinge_gan, optimized_hinge_gan.generator, n_samples = 100, latent_dim = 100)
model_fid = calcFID(model_images, num_images = 1000)
print(f"FID for this model is {model_fid}")
32/32 [==============================] - 4s 80ms/step
32/32 [==============================] - 3s 80ms/step
FID for this model is 69.90489118078709
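The `calcFID` helper used above is defined earlier in the notebook and is not reproduced here. As a rough illustration of what an FID score measures, the sketch below computes the standard Fréchet distance between Gaussians fitted to two sets of feature vectors. It assumes the features (for example, InceptionV3 activations) have already been extracted; the name `frechet_distance` is hypothetical and this is not necessarily how `calcFID` is implemented.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    Both inputs are 2-D arrays with one row per sample (assumed layout).
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.atleast_2d(np.cov(feats_a, rowvar=False))
    cov_b = np.atleast_2d(np.cov(feats_b, rowvar=False))
    # Tr((cov_a cov_b)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of cov_a @ cov_b, which are real and non-negative in theory.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)
```

Identical feature sets give a distance of zero, and the score grows as the means or covariances drift apart, which is why a lower FID indicates generated images that are statistically closer to the real ones.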
| Augmented Results | Unaugmented Results |
|---|---|
| [image] | [image] |
Comparing the models on both metrics and visual quality, we can see that the unaugmented model still performs better than the augmented model. Not only does the unaugmented model preserve detail in its images better than the augmented one, it also scores better on FID (where lower is better), with a score of about 69.9 compared to the augmented model's 92. Hence, we select the unaugmented model as our final model and use it to generate the required images.
The images will also be saved in the "/generated" directory.
The final model's weights will also be saved as best_weights.h5, and can be found in the same directory as this file.
model_images, _ = optimized_hinge_gan.generate_fake_samples(optimized_hinge_gan, optimized_hinge_gan.generator, n_samples = 100, latent_dim = 100)
model_images = (model_images + 1) / 2.0 # Scale images to [0, 1]
batches = list(batch_images(model_images, 100))
for i, batch in enumerate(batches):
    display_images_in_grid(batch, 10, title = f"Generated {class_names[i]} Images for Optimized Hinge GAN\n")
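The code that writes the images to disk is not shown in this section. A minimal sketch of one way to do it is given below, assuming the images are a float array already scaled to [0, 1] with shape (N, H, W, 3). The function name `save_images`, the file-name pattern, and the relative `generated` directory are illustrative assumptions rather than the notebook's actual implementation.

```python
import os
import numpy as np
import matplotlib.pyplot as plt

def save_images(images, out_dir="generated", prefix="img"):
    """Write each image of a [0, 1] float array as a PNG file; return the paths."""
    os.makedirs(out_dir, exist_ok=True)
    # Convert to 8-bit pixels, clipping values that drift outside [0, 1].
    pixels = (np.clip(images, 0.0, 1.0) * 255).round().astype(np.uint8)
    paths = []
    for idx, img in enumerate(pixels):
        path = os.path.join(out_dir, f"{prefix}_{idx:04d}.png")
        plt.imsave(path, img)  # hypothetical naming scheme: img_0000.png, ...
        paths.append(path)
    return paths
```

With the variables from the cell above, this could be called as `save_images(model_images)`. The weights could be written in a similar spirit with Keras's `save_weights` method (for example on the generator), though which sub-model was actually saved to best_weights.h5 is not stated here.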
To conclude, this task was a fun and intriguing one for both of us, and it gave us deeper insight into the considerations and limitations that have to be taken into account when developing such AI tools. We have also deepened our understanding and solidified our foundation in the field of deep learning, especially on the topic of Generative Adversarial Networks. Although this assignment was interesting, there are some ideas that we would have loved to try, such as SA-GAN models, more ways to stabilize or improve the models such as Gaussian weights, or even upscaling the data to generate images with higher fidelity.